# ============================================================================= # 企微IT智能服务台 — 测试配置与公共 fixtures # ============================================================================= # 说明:pytest 的全局 fixtures,包括: # 1. SQLite 内存数据库(替代 PostgreSQL) # 2. 模拟 Redis 客户端 # 3. FastAPI 测试客户端 # 4. 测试用数据库会话 # ============================================================================= import asyncio import uuid from datetime import datetime from typing import AsyncGenerator, Dict, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient from sqlalchemy import event from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.pool import StaticPool from app.database import Base from app.models.agent import Agent from app.models.conversation import Conversation from app.models.message import Message from app.models.system_config import SystemConfig from app.models.funny_phrase import FunnyPhrase from app.models.approval_link import ApprovalLink from app.models.software_download import SoftwareDownload from app.models.quick_reply_template import QuickReplyTemplate from app.models.agent_note import AgentNote # ============================================================================= # SQLite 内存数据库引擎 # ============================================================================= # 使用 aiosqlite 驱动的 SQLite 内存数据库替代 PostgreSQL # StaticPool 确保所有连接使用同一个内存数据库实例 # ============================================================================= TEST_DATABASE_URL = "sqlite+aiosqlite://" test_engine = create_async_engine( TEST_DATABASE_URL, connect_args={"check_same_thread": False}, poolclass=StaticPool, ) test_session_factory = async_sessionmaker( test_engine, class_=AsyncSession, expire_on_commit=False, ) # 为 SQLite 启用外键约束 @event.listens_for(test_engine.sync_engine, "connect") def _set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() # ============================================================================= # 模拟 Redis 客户端 # ============================================================================= class MockRedis: """模拟 Redis 客户端,使用内存字典存储数据。""" def __init__(self): self._data: Dict[str, str] = {} self._ttl: Dict[str, int] = {} async def get(self, key: str) -> Optional[bytes]: value = self._data.get(key) if value is not None: return value.encode("utf-8") if isinstance(value, str) else value return None async def setex(self, name: str, time: int, value: str) -> None: self._data[name] = value self._ttl[name] = time async def set(self, name: str, value: str, **kwargs) -> Optional[bool]: """模拟 Redis SET 命令,支持 nx 和 ex 参数。 Args: name: Redis key value: Redis value **kwargs: nx: SET IF NOT EXISTS — key 不存在时才设置,返回 True;已存在返回 None ex: 过期时间(秒) Returns: nx=True 时:True=设置成功,None=key 已存在未设置 其他情况:None(与真实 Redis SET 行为一致) """ nx = kwargs.get("nx", False) ex = kwargs.get("ex", None) if nx: if name in self._data: return None # key 已存在,SET NX 未设置 self._data[name] = value if ex is not None: self._ttl[name] = ex return True # 设置成功 self._data[name] = value if ex is not None: self._ttl[name] = ex return None async def delete(self, *names) -> int: count = 0 for name in names: if name in self._data: del self._data[name] count += 1 return count async def exists(self, *keys) -> int: return sum(1 for k in keys if k in self._data) async def expire(self, name: str, time: int) -> bool: if name in self._data: self._ttl[name] = time return True return False async def close(self) -> None: pass def reset(self) -> None: self._data.clear() self._ttl.clear() # ============================================================================= # Fixtures # ============================================================================= @pytest.fixture(scope="session") def event_loop(): """创建 session 级别的事件循环。""" loop = asyncio.new_event_loop() yield loop loop.close() @pytest_asyncio.fixture(scope="session", autouse=True) async def setup_database(): """创建所有数据库表(session 级别,只执行一次)。""" async with test_engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield async with test_engine.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @pytest_asyncio.fixture async def db_session() -> AsyncGenerator[AsyncSession, None]: """提供干净的数据库会话,每个测试用例使用独立事务并在测试后回滚。""" async with test_session_factory() as session: # 开始一个嵌套事务 nested = await session.begin_nested() try: yield session finally: # 回滚嵌套事务,确保数据库干净 if nested.is_active: await nested.rollback() # 清理会话 await session.close() @pytest.fixture def mock_redis() -> MockRedis: """提供模拟 Redis 客户端。""" return MockRedis() @pytest_asyncio.fixture async def client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncGenerator[AsyncClient, None]: """提供 FastAPI 异步测试客户端。""" async def _override_get_db(): yield db_session async def _override_get_redis(): return mock_redis from app.main import create_app from app.database import get_db app = create_app() # 覆盖数据库依赖 app.dependency_overrides[get_db] = _override_get_db # 模拟 Redis(同时 mock agents 和 h5 模块的 Redis 依赖) with patch("app.api.agents._get_redis", return_value=mock_redis): with patch("redis.asyncio.from_url", return_value=mock_redis): # ------------------------------------------------------------------ # Mock 外部服务:WecomService(企微API)和 AIService(AI大模型) # 为什么:测试中不应调用真实企微API/AI大模型 # 怎么做:patch 类构造函数,返回配置了默认返回值的 mock 对象 # ------------------------------------------------------------------ mock_wecom = AsyncMock() # 企微消息发送:默认成功 mock_wecom.send_message.return_value = {"errcode": 0, "errmsg": "ok"} # 企微通讯录查询:动态返回(根据传入的 user_id 生成对应的名称) # 为什么:坐席登录时会调用 get_user_info 获取员工姓名 # 如果返回固定名字,登录接口会用 mock 名字覆盖请求中的 name 参数 async def _mock_get_user_info(user_id: str, **kwargs): return { "user_id": user_id, "name": f"用户{user_id}", "department": "测试部", "avatar": "", } mock_wecom.get_user_info.side_effect = _mock_get_user_info mock_wecom.get_department_users.return_value = [] mock_ai = AsyncMock() mock_ai.generate_response.return_value = "这是AI的模拟回复" # Patch WecomService 类(端点函数中会新建实例) # 注意:只 patch 模块中实际引用的名字 # conversations.py 导入了 WecomService,但没有导入 AIService with patch("app.api.conversations.WecomService", return_value=mock_wecom): # h5.py 和 agents.py 也需要 patch with patch("app.api.h5.WecomService", return_value=mock_wecom): with patch("app.api.agents.WecomService", return_value=mock_wecom): with patch("app.api.agents._get_redis", return_value=mock_redis): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac app.dependency_overrides.clear() @pytest_asyncio.fixture async def seeded_db(db_session: AsyncSession) -> AsyncSession: """插入测试基础数据并返回会话。""" # 系统配置 configs = [ SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","人工服务","真人","客服"]', description="举手关键词"), SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉","差劲","垃圾"]', description="愤怒关键词"), SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上","立刻","赶紧"]', description="紧急关键词"), SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕","出错","丢失","完蛋"]', description="担忧关键词"), SystemConfig(config_key="intervene_round_threshold", config_value="3", description="介入阈值"), SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="基础加分"), SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪加成"), SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成"), SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复加成"), ] db_session.add_all(configs) # 趣味话术 phrases = [ FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1), FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1), ] db_session.add_all(phrases) # 审批链接 links = [ ApprovalLink(category="IT", title="软件安装申请", url="https://example.com/software", sort_order=1), ApprovalLink(category="HR", title="入职手续", url="https://example.com/onboarding", sort_order=2), ] db_session.add_all(links) # 软件下载 downloads = [ SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com", sort_order=1), SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com", sort_order=2), ] db_session.add_all(downloads) await db_session.flush() return db_session # ============================================================================= # 辅助函数 # ============================================================================= def create_test_conversation( employee_id: str = "test_employee_001", employee_name: str = "测试员工", status: str = "queued", is_vip: bool = False, is_pinned: bool = False, is_todo: bool = False, urgency_score: int = 1, tags: Optional[Dict] = None, ) -> Conversation: """创建测试用的会话对象。""" return Conversation( employee_id=employee_id, employee_name=employee_name, department="技术部", position="工程师", level="", status=status, is_vip=is_vip, is_pinned=is_pinned, is_todo=is_todo, urgency_score=urgency_score, tags=tags or {}, last_message_at=datetime.now(), last_message_summary="测试消息", ) def create_test_agent( user_id: str = "test_agent_001", name: str = "测试坐席", status: str = "online", ) -> Agent: """创建测试用的坐席对象。""" return Agent( user_id=user_id, name=name, status=status, current_load=0, max_load=5, )