Files

334 lines
12 KiB
Python
Raw Permalink Normal View History

# =============================================================================
# 企微IT智能服务台 — 测试配置与公共 fixtures
# =============================================================================
# 说明:pytest 的全局 fixtures,包括:
# 1. SQLite 内存数据库(替代 PostgreSQL
# 2. 模拟 Redis 客户端
# 3. FastAPI 测试客户端
# 4. 测试用数据库会话
# =============================================================================
import asyncio
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.models.system_config import SystemConfig
from app.models.funny_phrase import FunnyPhrase
from app.models.approval_link import ApprovalLink
from app.models.software_download import SoftwareDownload
from app.models.quick_reply_template import QuickReplyTemplate
from app.models.agent_note import AgentNote
# =============================================================================
# SQLite 内存数据库引擎
# =============================================================================
# 使用 aiosqlite 驱动的 SQLite 内存数据库替代 PostgreSQL
# StaticPool 确保所有连接使用同一个内存数据库实例
# =============================================================================
TEST_DATABASE_URL = "sqlite+aiosqlite://"
test_engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
test_session_factory = async_sessionmaker(
test_engine,
class_=AsyncSession,
expire_on_commit=False,
)
# 为 SQLite 启用外键约束
@event.listens_for(test_engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
# =============================================================================
# 模拟 Redis 客户端
# =============================================================================
class MockRedis:
"""模拟 Redis 客户端,使用内存字典存储数据。"""
def __init__(self):
self._data: Dict[str, str] = {}
self._ttl: Dict[str, int] = {}
async def get(self, key: str) -> Optional[bytes]:
value = self._data.get(key)
if value is not None:
return value.encode("utf-8") if isinstance(value, str) else value
return None
async def setex(self, name: str, time: int, value: str) -> None:
self._data[name] = value
self._ttl[name] = time
async def set(self, name: str, value: str, **kwargs) -> Optional[bool]:
"""模拟 Redis SET 命令,支持 nx 和 ex 参数。
Args:
name: Redis key
value: Redis value
**kwargs:
nx: SET IF NOT EXISTS — key 不存在时才设置,返回 True;已存在返回 None
ex: 过期时间(秒)
Returns:
nx=True 时:True=设置成功,None=key 已存在未设置
其他情况:None(与真实 Redis SET 行为一致)
"""
nx = kwargs.get("nx", False)
ex = kwargs.get("ex", None)
if nx:
if name in self._data:
return None # key 已存在,SET NX 未设置
self._data[name] = value
if ex is not None:
self._ttl[name] = ex
return True # 设置成功
self._data[name] = value
if ex is not None:
self._ttl[name] = ex
return None
async def delete(self, *names) -> int:
count = 0
for name in names:
if name in self._data:
del self._data[name]
count += 1
return count
async def exists(self, *keys) -> int:
return sum(1 for k in keys if k in self._data)
async def expire(self, name: str, time: int) -> bool:
if name in self._data:
self._ttl[name] = time
return True
return False
async def close(self) -> None:
pass
def reset(self) -> None:
self._data.clear()
self._ttl.clear()
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def event_loop():
"""创建 session 级别的事件循环。"""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
"""创建所有数据库表(session 级别,只执行一次)。"""
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""提供干净的数据库会话,每个测试用例使用独立事务并在测试后回滚。"""
async with test_session_factory() as session:
# 开始一个嵌套事务
nested = await session.begin_nested()
try:
yield session
finally:
# 回滚嵌套事务,确保数据库干净
if nested.is_active:
await nested.rollback()
# 清理会话
await session.close()
@pytest.fixture
def mock_redis() -> MockRedis:
"""提供模拟 Redis 客户端。"""
return MockRedis()
@pytest_asyncio.fixture
async def client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncGenerator[AsyncClient, None]:
"""提供 FastAPI 异步测试客户端。"""
async def _override_get_db():
yield db_session
async def _override_get_redis():
return mock_redis
from app.main import create_app
from app.database import get_db
app = create_app()
# 覆盖数据库依赖
app.dependency_overrides[get_db] = _override_get_db
# 模拟 Redis(同时 mock agents 和 h5 模块的 Redis 依赖)
with patch("app.api.agents._get_redis", return_value=mock_redis):
with patch("redis.asyncio.from_url", return_value=mock_redis):
# ------------------------------------------------------------------
# Mock 外部服务:WecomService(企微API)和 AIServiceAI大模型)
# 为什么:测试中不应调用真实企微API/AI大模型
# 怎么做:patch 类构造函数,返回配置了默认返回值的 mock 对象
# ------------------------------------------------------------------
mock_wecom = AsyncMock()
# 企微消息发送:默认成功
mock_wecom.send_message.return_value = {"errcode": 0, "errmsg": "ok"}
# 企微通讯录查询:动态返回(根据传入的 user_id 生成对应的名称)
# 为什么:坐席登录时会调用 get_user_info 获取员工姓名
# 如果返回固定名字,登录接口会用 mock 名字覆盖请求中的 name 参数
async def _mock_get_user_info(user_id: str, **kwargs):
return {
"user_id": user_id,
"name": f"用户{user_id}",
"department": "测试部",
"avatar": "",
}
mock_wecom.get_user_info.side_effect = _mock_get_user_info
mock_wecom.get_department_users.return_value = []
mock_ai = AsyncMock()
mock_ai.generate_response.return_value = "这是AI的模拟回复"
# Patch WecomService 类(端点函数中会新建实例)
# 注意:只 patch 模块中实际引用的名字
# conversations.py 导入了 WecomService,但没有导入 AIService
with patch("app.api.conversations.WecomService", return_value=mock_wecom):
# h5.py 和 agents.py 也需要 patch
with patch("app.api.h5.WecomService", return_value=mock_wecom):
with patch("app.api.agents.WecomService", return_value=mock_wecom):
with patch("app.api.agents._get_redis", return_value=mock_redis):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def seeded_db(db_session: AsyncSession) -> AsyncSession:
"""插入测试基础数据并返回会话。"""
# 系统配置
configs = [
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","人工服务","真人","客服"]', description="举手关键词"),
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉","差劲","垃圾"]', description="愤怒关键词"),
SystemConfig(config_key="emotion_keywords_urgent", config_value='["","紧急","马上","立刻","赶紧"]', description="紧急关键词"),
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕","出错","丢失","完蛋"]', description="担忧关键词"),
SystemConfig(config_key="intervene_round_threshold", config_value="3", description="介入阈值"),
SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="基础加分"),
SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪加成"),
SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成"),
SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复加成"),
]
db_session.add_all(configs)
# 趣味话术
phrases = [
FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1),
FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1),
]
db_session.add_all(phrases)
# 审批链接
links = [
ApprovalLink(category="IT", title="软件安装申请", url="https://example.com/software", sort_order=1),
ApprovalLink(category="HR", title="入职手续", url="https://example.com/onboarding", sort_order=2),
]
db_session.add_all(links)
# 软件下载
downloads = [
SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com", sort_order=1),
SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com", sort_order=2),
]
db_session.add_all(downloads)
await db_session.flush()
return db_session
# =============================================================================
# 辅助函数
# =============================================================================
def create_test_conversation(
employee_id: str = "test_employee_001",
employee_name: str = "测试员工",
status: str = "queued",
is_vip: bool = False,
is_pinned: bool = False,
is_todo: bool = False,
urgency_score: int = 1,
tags: Optional[Dict] = None,
) -> Conversation:
"""创建测试用的会话对象。"""
return Conversation(
employee_id=employee_id,
employee_name=employee_name,
department="技术部",
position="工程师",
level="",
status=status,
is_vip=is_vip,
is_pinned=is_pinned,
is_todo=is_todo,
urgency_score=urgency_score,
tags=tags or {},
last_message_at=datetime.now(),
last_message_summary="测试消息",
)
def create_test_agent(
user_id: str = "test_agent_001",
name: str = "测试坐席",
status: str = "online",
) -> Agent:
"""创建测试用的坐席对象。"""
return Agent(
user_id=user_id,
name=name,
status=status,
current_load=0,
max_load=5,
)