334 lines
12 KiB
Python
334 lines
12 KiB
Python
# =============================================================================
|
||
# 企微IT智能服务台 — 测试配置与公共 fixtures
|
||
# =============================================================================
|
||
# 说明:pytest 的全局 fixtures,包括:
|
||
# 1. SQLite 内存数据库(替代 PostgreSQL)
|
||
# 2. 模拟 Redis 客户端
|
||
# 3. FastAPI 测试客户端
|
||
# 4. 测试用数据库会话
|
||
# =============================================================================
|
||
|
||
import asyncio
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import AsyncGenerator, Dict, Optional
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from httpx import ASGITransport, AsyncClient
|
||
from sqlalchemy import event
|
||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.models.agent import Agent
|
||
from app.models.conversation import Conversation
|
||
from app.models.message import Message
|
||
from app.models.system_config import SystemConfig
|
||
from app.models.funny_phrase import FunnyPhrase
|
||
from app.models.approval_link import ApprovalLink
|
||
from app.models.software_download import SoftwareDownload
|
||
from app.models.quick_reply_template import QuickReplyTemplate
|
||
from app.models.agent_note import AgentNote
|
||
|
||
|
||
# =============================================================================
|
||
# SQLite 内存数据库引擎
|
||
# =============================================================================
|
||
# 使用 aiosqlite 驱动的 SQLite 内存数据库替代 PostgreSQL
|
||
# StaticPool 确保所有连接使用同一个内存数据库实例
|
||
# =============================================================================
|
||
|
||
TEST_DATABASE_URL = "sqlite+aiosqlite://"
|
||
|
||
test_engine = create_async_engine(
|
||
TEST_DATABASE_URL,
|
||
connect_args={"check_same_thread": False},
|
||
poolclass=StaticPool,
|
||
)
|
||
|
||
test_session_factory = async_sessionmaker(
|
||
test_engine,
|
||
class_=AsyncSession,
|
||
expire_on_commit=False,
|
||
)
|
||
|
||
|
||
# 为 SQLite 启用外键约束
|
||
@event.listens_for(test_engine.sync_engine, "connect")
|
||
def _set_sqlite_pragma(dbapi_connection, connection_record):
|
||
cursor = dbapi_connection.cursor()
|
||
cursor.execute("PRAGMA foreign_keys=ON")
|
||
cursor.close()
|
||
|
||
|
||
# =============================================================================
|
||
# 模拟 Redis 客户端
|
||
# =============================================================================
|
||
|
||
class MockRedis:
|
||
"""模拟 Redis 客户端,使用内存字典存储数据。"""
|
||
|
||
def __init__(self):
|
||
self._data: Dict[str, str] = {}
|
||
self._ttl: Dict[str, int] = {}
|
||
|
||
async def get(self, key: str) -> Optional[bytes]:
|
||
value = self._data.get(key)
|
||
if value is not None:
|
||
return value.encode("utf-8") if isinstance(value, str) else value
|
||
return None
|
||
|
||
async def setex(self, name: str, time: int, value: str) -> None:
|
||
self._data[name] = value
|
||
self._ttl[name] = time
|
||
|
||
async def set(self, name: str, value: str, **kwargs) -> Optional[bool]:
|
||
"""模拟 Redis SET 命令,支持 nx 和 ex 参数。
|
||
|
||
Args:
|
||
name: Redis key
|
||
value: Redis value
|
||
**kwargs:
|
||
nx: SET IF NOT EXISTS — key 不存在时才设置,返回 True;已存在返回 None
|
||
ex: 过期时间(秒)
|
||
|
||
Returns:
|
||
nx=True 时:True=设置成功,None=key 已存在未设置
|
||
其他情况:None(与真实 Redis SET 行为一致)
|
||
"""
|
||
nx = kwargs.get("nx", False)
|
||
ex = kwargs.get("ex", None)
|
||
|
||
if nx:
|
||
if name in self._data:
|
||
return None # key 已存在,SET NX 未设置
|
||
self._data[name] = value
|
||
if ex is not None:
|
||
self._ttl[name] = ex
|
||
return True # 设置成功
|
||
|
||
self._data[name] = value
|
||
if ex is not None:
|
||
self._ttl[name] = ex
|
||
return None
|
||
|
||
async def delete(self, *names) -> int:
|
||
count = 0
|
||
for name in names:
|
||
if name in self._data:
|
||
del self._data[name]
|
||
count += 1
|
||
return count
|
||
|
||
async def exists(self, *keys) -> int:
|
||
return sum(1 for k in keys if k in self._data)
|
||
|
||
async def expire(self, name: str, time: int) -> bool:
|
||
if name in self._data:
|
||
self._ttl[name] = time
|
||
return True
|
||
return False
|
||
|
||
async def close(self) -> None:
|
||
pass
|
||
|
||
def reset(self) -> None:
|
||
self._data.clear()
|
||
self._ttl.clear()
|
||
|
||
|
||
# =============================================================================
|
||
# Fixtures
|
||
# =============================================================================
|
||
|
||
|
||
@pytest.fixture(scope="session")
|
||
def event_loop():
|
||
"""创建 session 级别的事件循环。"""
|
||
loop = asyncio.new_event_loop()
|
||
yield loop
|
||
loop.close()
|
||
|
||
|
||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||
async def setup_database():
|
||
"""创建所有数据库表(session 级别,只执行一次)。"""
|
||
async with test_engine.begin() as conn:
|
||
await conn.run_sync(Base.metadata.create_all)
|
||
yield
|
||
async with test_engine.begin() as conn:
|
||
await conn.run_sync(Base.metadata.drop_all)
|
||
|
||
|
||
@pytest_asyncio.fixture
|
||
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||
"""提供干净的数据库会话,每个测试用例使用独立事务并在测试后回滚。"""
|
||
async with test_session_factory() as session:
|
||
# 开始一个嵌套事务
|
||
nested = await session.begin_nested()
|
||
try:
|
||
yield session
|
||
finally:
|
||
# 回滚嵌套事务,确保数据库干净
|
||
if nested.is_active:
|
||
await nested.rollback()
|
||
# 清理会话
|
||
await session.close()
|
||
|
||
|
||
@pytest.fixture
|
||
def mock_redis() -> MockRedis:
|
||
"""提供模拟 Redis 客户端。"""
|
||
return MockRedis()
|
||
|
||
|
||
@pytest_asyncio.fixture
|
||
async def client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncGenerator[AsyncClient, None]:
|
||
"""提供 FastAPI 异步测试客户端。"""
|
||
|
||
async def _override_get_db():
|
||
yield db_session
|
||
|
||
async def _override_get_redis():
|
||
return mock_redis
|
||
|
||
from app.main import create_app
|
||
from app.database import get_db
|
||
|
||
app = create_app()
|
||
|
||
# 覆盖数据库依赖
|
||
app.dependency_overrides[get_db] = _override_get_db
|
||
|
||
# 模拟 Redis(同时 mock agents 和 h5 模块的 Redis 依赖)
|
||
with patch("app.api.agents._get_redis", return_value=mock_redis):
|
||
with patch("redis.asyncio.from_url", return_value=mock_redis):
|
||
# ------------------------------------------------------------------
|
||
# Mock 外部服务:WecomService(企微API)和 AIService(AI大模型)
|
||
# 为什么:测试中不应调用真实企微API/AI大模型
|
||
# 怎么做:patch 类构造函数,返回配置了默认返回值的 mock 对象
|
||
# ------------------------------------------------------------------
|
||
mock_wecom = AsyncMock()
|
||
# 企微消息发送:默认成功
|
||
mock_wecom.send_message.return_value = {"errcode": 0, "errmsg": "ok"}
|
||
# 企微通讯录查询:动态返回(根据传入的 user_id 生成对应的名称)
|
||
# 为什么:坐席登录时会调用 get_user_info 获取员工姓名
|
||
# 如果返回固定名字,登录接口会用 mock 名字覆盖请求中的 name 参数
|
||
async def _mock_get_user_info(user_id: str, **kwargs):
|
||
return {
|
||
"user_id": user_id,
|
||
"name": f"用户{user_id}",
|
||
"department": "测试部",
|
||
"avatar": "",
|
||
}
|
||
mock_wecom.get_user_info.side_effect = _mock_get_user_info
|
||
mock_wecom.get_department_users.return_value = []
|
||
|
||
mock_ai = AsyncMock()
|
||
mock_ai.generate_response.return_value = "这是AI的模拟回复"
|
||
|
||
# Patch WecomService 类(端点函数中会新建实例)
|
||
# 注意:只 patch 模块中实际引用的名字
|
||
# conversations.py 导入了 WecomService,但没有导入 AIService
|
||
with patch("app.api.conversations.WecomService", return_value=mock_wecom):
|
||
# h5.py 和 agents.py 也需要 patch
|
||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||
with patch("app.api.agents.WecomService", return_value=mock_wecom):
|
||
with patch("app.api.agents._get_redis", return_value=mock_redis):
|
||
transport = ASGITransport(app=app)
|
||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||
yield ac
|
||
|
||
app.dependency_overrides.clear()
|
||
|
||
|
||
@pytest_asyncio.fixture
|
||
async def seeded_db(db_session: AsyncSession) -> AsyncSession:
|
||
"""插入测试基础数据并返回会话。"""
|
||
# 系统配置
|
||
configs = [
|
||
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","人工服务","真人","客服"]', description="举手关键词"),
|
||
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉","差劲","垃圾"]', description="愤怒关键词"),
|
||
SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上","立刻","赶紧"]', description="紧急关键词"),
|
||
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕","出错","丢失","完蛋"]', description="担忧关键词"),
|
||
SystemConfig(config_key="intervene_round_threshold", config_value="3", description="介入阈值"),
|
||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="基础加分"),
|
||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪加成"),
|
||
SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成"),
|
||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复加成"),
|
||
]
|
||
db_session.add_all(configs)
|
||
|
||
# 趣味话术
|
||
phrases = [
|
||
FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1),
|
||
FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1),
|
||
]
|
||
db_session.add_all(phrases)
|
||
|
||
# 审批链接
|
||
links = [
|
||
ApprovalLink(category="IT", title="软件安装申请", url="https://example.com/software", sort_order=1),
|
||
ApprovalLink(category="HR", title="入职手续", url="https://example.com/onboarding", sort_order=2),
|
||
]
|
||
db_session.add_all(links)
|
||
|
||
# 软件下载
|
||
downloads = [
|
||
SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com", sort_order=1),
|
||
SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com", sort_order=2),
|
||
]
|
||
db_session.add_all(downloads)
|
||
|
||
await db_session.flush()
|
||
return db_session
|
||
|
||
|
||
# =============================================================================
|
||
# 辅助函数
|
||
# =============================================================================
|
||
|
||
def create_test_conversation(
|
||
employee_id: str = "test_employee_001",
|
||
employee_name: str = "测试员工",
|
||
status: str = "queued",
|
||
is_vip: bool = False,
|
||
is_pinned: bool = False,
|
||
is_todo: bool = False,
|
||
urgency_score: int = 1,
|
||
tags: Optional[Dict] = None,
|
||
) -> Conversation:
|
||
"""创建测试用的会话对象。"""
|
||
return Conversation(
|
||
employee_id=employee_id,
|
||
employee_name=employee_name,
|
||
department="技术部",
|
||
position="工程师",
|
||
level="",
|
||
status=status,
|
||
is_vip=is_vip,
|
||
is_pinned=is_pinned,
|
||
is_todo=is_todo,
|
||
urgency_score=urgency_score,
|
||
tags=tags or {},
|
||
last_message_at=datetime.now(),
|
||
last_message_summary="测试消息",
|
||
)
|
||
|
||
|
||
def create_test_agent(
|
||
user_id: str = "test_agent_001",
|
||
name: str = "测试坐席",
|
||
status: str = "online",
|
||
) -> Agent:
|
||
"""创建测试用的坐席对象。"""
|
||
return Agent(
|
||
user_id=user_id,
|
||
name=name,
|
||
status=status,
|
||
current_load=0,
|
||
max_load=5,
|
||
)
|