Files
wecom_it_smart_desk/backend/tests/conftest.py
T

334 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — 测试配置与公共 fixtures
# =============================================================================
# 说明:pytest 的全局 fixtures,包括:
# 1. SQLite 内存数据库(替代 PostgreSQL
# 2. 模拟 Redis 客户端
# 3. FastAPI 测试客户端
# 4. 测试用数据库会话
# =============================================================================
import asyncio
import uuid
from datetime import datetime
from typing import AsyncGenerator, Dict, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy import event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.models.system_config import SystemConfig
from app.models.funny_phrase import FunnyPhrase
from app.models.approval_link import ApprovalLink
from app.models.software_download import SoftwareDownload
from app.models.quick_reply_template import QuickReplyTemplate
from app.models.agent_note import AgentNote
# =============================================================================
# SQLite 内存数据库引擎
# =============================================================================
# 使用 aiosqlite 驱动的 SQLite 内存数据库替代 PostgreSQL
# StaticPool 确保所有连接使用同一个内存数据库实例
# =============================================================================
TEST_DATABASE_URL = "sqlite+aiosqlite://"
test_engine = create_async_engine(
TEST_DATABASE_URL,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
test_session_factory = async_sessionmaker(
test_engine,
class_=AsyncSession,
expire_on_commit=False,
)
# 为 SQLite 启用外键约束
@event.listens_for(test_engine.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
# =============================================================================
# 模拟 Redis 客户端
# =============================================================================
class MockRedis:
"""模拟 Redis 客户端,使用内存字典存储数据。"""
def __init__(self):
self._data: Dict[str, str] = {}
self._ttl: Dict[str, int] = {}
async def get(self, key: str) -> Optional[bytes]:
value = self._data.get(key)
if value is not None:
return value.encode("utf-8") if isinstance(value, str) else value
return None
async def setex(self, name: str, time: int, value: str) -> None:
self._data[name] = value
self._ttl[name] = time
async def set(self, name: str, value: str, **kwargs) -> Optional[bool]:
"""模拟 Redis SET 命令,支持 nx 和 ex 参数。
Args:
name: Redis key
value: Redis value
**kwargs:
nx: SET IF NOT EXISTS — key 不存在时才设置,返回 True;已存在返回 None
ex: 过期时间(秒)
Returns:
nx=True 时:True=设置成功,None=key 已存在未设置
其他情况:None(与真实 Redis SET 行为一致)
"""
nx = kwargs.get("nx", False)
ex = kwargs.get("ex", None)
if nx:
if name in self._data:
return None # key 已存在,SET NX 未设置
self._data[name] = value
if ex is not None:
self._ttl[name] = ex
return True # 设置成功
self._data[name] = value
if ex is not None:
self._ttl[name] = ex
return None
async def delete(self, *names) -> int:
count = 0
for name in names:
if name in self._data:
del self._data[name]
count += 1
return count
async def exists(self, *keys) -> int:
return sum(1 for k in keys if k in self._data)
async def expire(self, name: str, time: int) -> bool:
if name in self._data:
self._ttl[name] = time
return True
return False
async def close(self) -> None:
pass
def reset(self) -> None:
self._data.clear()
self._ttl.clear()
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(scope="session")
def event_loop():
"""创建 session 级别的事件循环。"""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_database():
"""创建所有数据库表(session 级别,只执行一次)。"""
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield
async with test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""提供干净的数据库会话,每个测试用例使用独立事务并在测试后回滚。"""
async with test_session_factory() as session:
# 开始一个嵌套事务
nested = await session.begin_nested()
try:
yield session
finally:
# 回滚嵌套事务,确保数据库干净
if nested.is_active:
await nested.rollback()
# 清理会话
await session.close()
@pytest.fixture
def mock_redis() -> MockRedis:
"""提供模拟 Redis 客户端。"""
return MockRedis()
@pytest_asyncio.fixture
async def client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncGenerator[AsyncClient, None]:
"""提供 FastAPI 异步测试客户端。"""
async def _override_get_db():
yield db_session
async def _override_get_redis():
return mock_redis
from app.main import create_app
from app.database import get_db
app = create_app()
# 覆盖数据库依赖
app.dependency_overrides[get_db] = _override_get_db
# 模拟 Redis(同时 mock agents 和 h5 模块的 Redis 依赖)
with patch("app.api.agents._get_redis", return_value=mock_redis):
with patch("redis.asyncio.from_url", return_value=mock_redis):
# ------------------------------------------------------------------
# Mock 外部服务:WecomService(企微API)和 AIServiceAI大模型)
# 为什么:测试中不应调用真实企微API/AI大模型
# 怎么做:patch 类构造函数,返回配置了默认返回值的 mock 对象
# ------------------------------------------------------------------
mock_wecom = AsyncMock()
# 企微消息发送:默认成功
mock_wecom.send_message.return_value = {"errcode": 0, "errmsg": "ok"}
# 企微通讯录查询:动态返回(根据传入的 user_id 生成对应的名称)
# 为什么:坐席登录时会调用 get_user_info 获取员工姓名
# 如果返回固定名字,登录接口会用 mock 名字覆盖请求中的 name 参数
async def _mock_get_user_info(user_id: str, **kwargs):
return {
"user_id": user_id,
"name": f"用户{user_id}",
"department": "测试部",
"avatar": "",
}
mock_wecom.get_user_info.side_effect = _mock_get_user_info
mock_wecom.get_department_users.return_value = []
mock_ai = AsyncMock()
mock_ai.generate_response.return_value = "这是AI的模拟回复"
# Patch WecomService 类(端点函数中会新建实例)
# 注意:只 patch 模块中实际引用的名字
# conversations.py 导入了 WecomService,但没有导入 AIService
with patch("app.api.conversations.WecomService", return_value=mock_wecom):
# h5.py 和 agents.py 也需要 patch
with patch("app.api.h5.WecomService", return_value=mock_wecom):
with patch("app.api.agents.WecomService", return_value=mock_wecom):
with patch("app.api.agents._get_redis", return_value=mock_redis):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def seeded_db(db_session: AsyncSession) -> AsyncSession:
"""插入测试基础数据并返回会话。"""
# 系统配置
configs = [
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","人工服务","真人","客服"]', description="举手关键词"),
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉","差劲","垃圾"]', description="愤怒关键词"),
SystemConfig(config_key="emotion_keywords_urgent", config_value='["","紧急","马上","立刻","赶紧"]', description="紧急关键词"),
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕","出错","丢失","完蛋"]', description="担忧关键词"),
SystemConfig(config_key="intervene_round_threshold", config_value="3", description="介入阈值"),
SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="基础加分"),
SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪加成"),
SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成"),
SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复加成"),
]
db_session.add_all(configs)
# 趣味话术
phrases = [
FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1),
FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1),
]
db_session.add_all(phrases)
# 审批链接
links = [
ApprovalLink(category="IT", title="软件安装申请", url="https://example.com/software", sort_order=1),
ApprovalLink(category="HR", title="入职手续", url="https://example.com/onboarding", sort_order=2),
]
db_session.add_all(links)
# 软件下载
downloads = [
SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com", sort_order=1),
SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com", sort_order=2),
]
db_session.add_all(downloads)
await db_session.flush()
return db_session
# =============================================================================
# 辅助函数
# =============================================================================
def create_test_conversation(
employee_id: str = "test_employee_001",
employee_name: str = "测试员工",
status: str = "queued",
is_vip: bool = False,
is_pinned: bool = False,
is_todo: bool = False,
urgency_score: int = 1,
tags: Optional[Dict] = None,
) -> Conversation:
"""创建测试用的会话对象。"""
return Conversation(
employee_id=employee_id,
employee_name=employee_name,
department="技术部",
position="工程师",
level="",
status=status,
is_vip=is_vip,
is_pinned=is_pinned,
is_todo=is_todo,
urgency_score=urgency_score,
tags=tags or {},
last_message_at=datetime.now(),
last_message_summary="测试消息",
)
def create_test_agent(
user_id: str = "test_agent_001",
name: str = "测试坐席",
status: str = "online",
) -> Agent:
"""创建测试用的坐席对象。"""
return Agent(
user_id=user_id,
name=name,
status=status,
current_load=0,
max_load=5,
)