chore: initial baseline with P0-safety .gitignore
This commit is contained in:
@@ -0,0 +1,333 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 测试配置与公共 fixtures
|
||||
# =============================================================================
|
||||
# 说明:pytest 的全局 fixtures,包括:
|
||||
# 1. SQLite 内存数据库(替代 PostgreSQL)
|
||||
# 2. 模拟 Redis 客户端
|
||||
# 3. FastAPI 测试客户端
|
||||
# 4. 测试用数据库会话
|
||||
# =============================================================================
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Dict, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database import Base
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.models.funny_phrase import FunnyPhrase
|
||||
from app.models.approval_link import ApprovalLink
|
||||
from app.models.software_download import SoftwareDownload
|
||||
from app.models.quick_reply_template import QuickReplyTemplate
|
||||
from app.models.agent_note import AgentNote
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SQLite 内存数据库引擎
|
||||
# =============================================================================
|
||||
# 使用 aiosqlite 驱动的 SQLite 内存数据库替代 PostgreSQL
|
||||
# StaticPool 确保所有连接使用同一个内存数据库实例
|
||||
# =============================================================================
|
||||
|
||||
TEST_DATABASE_URL = "sqlite+aiosqlite://"
|
||||
|
||||
test_engine = create_async_engine(
|
||||
TEST_DATABASE_URL,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
test_session_factory = async_sessionmaker(
|
||||
test_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
# 为 SQLite 启用外键约束
|
||||
@event.listens_for(test_engine.sync_engine, "connect")
|
||||
def _set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 模拟 Redis 客户端
|
||||
# =============================================================================
|
||||
|
||||
class MockRedis:
|
||||
"""模拟 Redis 客户端,使用内存字典存储数据。"""
|
||||
|
||||
def __init__(self):
|
||||
self._data: Dict[str, str] = {}
|
||||
self._ttl: Dict[str, int] = {}
|
||||
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
value = self._data.get(key)
|
||||
if value is not None:
|
||||
return value.encode("utf-8") if isinstance(value, str) else value
|
||||
return None
|
||||
|
||||
async def setex(self, name: str, time: int, value: str) -> None:
|
||||
self._data[name] = value
|
||||
self._ttl[name] = time
|
||||
|
||||
async def set(self, name: str, value: str, **kwargs) -> Optional[bool]:
|
||||
"""模拟 Redis SET 命令,支持 nx 和 ex 参数。
|
||||
|
||||
Args:
|
||||
name: Redis key
|
||||
value: Redis value
|
||||
**kwargs:
|
||||
nx: SET IF NOT EXISTS — key 不存在时才设置,返回 True;已存在返回 None
|
||||
ex: 过期时间(秒)
|
||||
|
||||
Returns:
|
||||
nx=True 时:True=设置成功,None=key 已存在未设置
|
||||
其他情况:None(与真实 Redis SET 行为一致)
|
||||
"""
|
||||
nx = kwargs.get("nx", False)
|
||||
ex = kwargs.get("ex", None)
|
||||
|
||||
if nx:
|
||||
if name in self._data:
|
||||
return None # key 已存在,SET NX 未设置
|
||||
self._data[name] = value
|
||||
if ex is not None:
|
||||
self._ttl[name] = ex
|
||||
return True # 设置成功
|
||||
|
||||
self._data[name] = value
|
||||
if ex is not None:
|
||||
self._ttl[name] = ex
|
||||
return None
|
||||
|
||||
async def delete(self, *names) -> int:
|
||||
count = 0
|
||||
for name in names:
|
||||
if name in self._data:
|
||||
del self._data[name]
|
||||
count += 1
|
||||
return count
|
||||
|
||||
async def exists(self, *keys) -> int:
|
||||
return sum(1 for k in keys if k in self._data)
|
||||
|
||||
async def expire(self, name: str, time: int) -> bool:
|
||||
if name in self._data:
|
||||
self._ttl[name] = time
|
||||
return True
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def reset(self) -> None:
|
||||
self._data.clear()
|
||||
self._ttl.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""创建 session 级别的事件循环。"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_database():
|
||||
"""创建所有数据库表(session 级别,只执行一次)。"""
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield
|
||||
async with test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""提供干净的数据库会话,每个测试用例使用独立事务并在测试后回滚。"""
|
||||
async with test_session_factory() as session:
|
||||
# 开始一个嵌套事务
|
||||
nested = await session.begin_nested()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# 回滚嵌套事务,确保数据库干净
|
||||
if nested.is_active:
|
||||
await nested.rollback()
|
||||
# 清理会话
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis() -> MockRedis:
|
||||
"""提供模拟 Redis 客户端。"""
|
||||
return MockRedis()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncGenerator[AsyncClient, None]:
|
||||
"""提供 FastAPI 异步测试客户端。"""
|
||||
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
async def _override_get_redis():
|
||||
return mock_redis
|
||||
|
||||
from app.main import create_app
|
||||
from app.database import get_db
|
||||
|
||||
app = create_app()
|
||||
|
||||
# 覆盖数据库依赖
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
|
||||
# 模拟 Redis(同时 mock agents 和 h5 模块的 Redis 依赖)
|
||||
with patch("app.api.agents._get_redis", return_value=mock_redis):
|
||||
with patch("redis.asyncio.from_url", return_value=mock_redis):
|
||||
# ------------------------------------------------------------------
|
||||
# Mock 外部服务:WecomService(企微API)和 AIService(AI大模型)
|
||||
# 为什么:测试中不应调用真实企微API/AI大模型
|
||||
# 怎么做:patch 类构造函数,返回配置了默认返回值的 mock 对象
|
||||
# ------------------------------------------------------------------
|
||||
mock_wecom = AsyncMock()
|
||||
# 企微消息发送:默认成功
|
||||
mock_wecom.send_message.return_value = {"errcode": 0, "errmsg": "ok"}
|
||||
# 企微通讯录查询:动态返回(根据传入的 user_id 生成对应的名称)
|
||||
# 为什么:坐席登录时会调用 get_user_info 获取员工姓名
|
||||
# 如果返回固定名字,登录接口会用 mock 名字覆盖请求中的 name 参数
|
||||
async def _mock_get_user_info(user_id: str, **kwargs):
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"name": f"用户{user_id}",
|
||||
"department": "测试部",
|
||||
"avatar": "",
|
||||
}
|
||||
mock_wecom.get_user_info.side_effect = _mock_get_user_info
|
||||
mock_wecom.get_department_users.return_value = []
|
||||
|
||||
mock_ai = AsyncMock()
|
||||
mock_ai.generate_response.return_value = "这是AI的模拟回复"
|
||||
|
||||
# Patch WecomService 类(端点函数中会新建实例)
|
||||
# 注意:只 patch 模块中实际引用的名字
|
||||
# conversations.py 导入了 WecomService,但没有导入 AIService
|
||||
with patch("app.api.conversations.WecomService", return_value=mock_wecom):
|
||||
# h5.py 和 agents.py 也需要 patch
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
with patch("app.api.agents.WecomService", return_value=mock_wecom):
|
||||
with patch("app.api.agents._get_redis", return_value=mock_redis):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seeded_db(db_session: AsyncSession) -> AsyncSession:
|
||||
"""插入测试基础数据并返回会话。"""
|
||||
# 系统配置
|
||||
configs = [
|
||||
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","人工服务","真人","客服"]', description="举手关键词"),
|
||||
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉","差劲","垃圾"]', description="愤怒关键词"),
|
||||
SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上","立刻","赶紧"]', description="紧急关键词"),
|
||||
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕","出错","丢失","完蛋"]', description="担忧关键词"),
|
||||
SystemConfig(config_key="intervene_round_threshold", config_value="3", description="介入阈值"),
|
||||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="基础加分"),
|
||||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪加成"),
|
||||
SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成"),
|
||||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复加成"),
|
||||
]
|
||||
db_session.add_all(configs)
|
||||
|
||||
# 趣味话术
|
||||
phrases = [
|
||||
FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1),
|
||||
FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1),
|
||||
]
|
||||
db_session.add_all(phrases)
|
||||
|
||||
# 审批链接
|
||||
links = [
|
||||
ApprovalLink(category="IT", title="软件安装申请", url="https://example.com/software", sort_order=1),
|
||||
ApprovalLink(category="HR", title="入职手续", url="https://example.com/onboarding", sort_order=2),
|
||||
]
|
||||
db_session.add_all(links)
|
||||
|
||||
# 软件下载
|
||||
downloads = [
|
||||
SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com", sort_order=1),
|
||||
SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com", sort_order=2),
|
||||
]
|
||||
db_session.add_all(downloads)
|
||||
|
||||
await db_session.flush()
|
||||
return db_session
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
|
||||
def create_test_conversation(
|
||||
employee_id: str = "test_employee_001",
|
||||
employee_name: str = "测试员工",
|
||||
status: str = "queued",
|
||||
is_vip: bool = False,
|
||||
is_pinned: bool = False,
|
||||
is_todo: bool = False,
|
||||
urgency_score: int = 1,
|
||||
tags: Optional[Dict] = None,
|
||||
) -> Conversation:
|
||||
"""创建测试用的会话对象。"""
|
||||
return Conversation(
|
||||
employee_id=employee_id,
|
||||
employee_name=employee_name,
|
||||
department="技术部",
|
||||
position="工程师",
|
||||
level="",
|
||||
status=status,
|
||||
is_vip=is_vip,
|
||||
is_pinned=is_pinned,
|
||||
is_todo=is_todo,
|
||||
urgency_score=urgency_score,
|
||||
tags=tags or {},
|
||||
last_message_at=datetime.now(),
|
||||
last_message_summary="测试消息",
|
||||
)
|
||||
|
||||
|
||||
def create_test_agent(
|
||||
user_id: str = "test_agent_001",
|
||||
name: str = "测试坐席",
|
||||
status: str = "online",
|
||||
) -> Agent:
|
||||
"""创建测试用的坐席对象。"""
|
||||
return Agent(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
status=status,
|
||||
current_load=0,
|
||||
max_load=5,
|
||||
)
|
||||
@@ -0,0 +1,213 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 坐席认证与管理测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 坐席登录(新坐席注册 + 已有坐席重新登录)
|
||||
# 2. Token 存 Redis(验证 TTL 和格式)
|
||||
# 3. 获取当前坐席信息(有效 Token / 无效 Token / 过期 Token)
|
||||
# 4. 更新坐席状态(online/busy/offline)
|
||||
# 5. 无效状态值校验
|
||||
# 6. 获取坐席列表
|
||||
# 7. 缺少 Authorization 头返回未授权
|
||||
# =============================================================================
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.models.agent import Agent
|
||||
from tests.conftest import create_test_agent, MockRedis
|
||||
|
||||
|
||||
class TestAgentLogin:
|
||||
"""测试坐席登录。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_new_agent(self, client, db_session, mock_redis):
|
||||
"""验证新坐席首次登录自动注册。"""
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "new_agent_001", "name": "新坐席"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["user_id"] == "new_agent_001"
|
||||
assert data["data"]["name"] == "新坐席"
|
||||
assert data["data"]["status"] == "online"
|
||||
assert "token" in data["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_existing_agent(self, client, db_session, mock_redis):
|
||||
"""验证已有坐席重新登录更新信息。"""
|
||||
# 先创建坐席
|
||||
agent = create_test_agent(user_id="exist_agent", name="旧名字")
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "exist_agent", "name": "新名字"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["data"]["name"] == "新名字"
|
||||
assert data["data"]["status"] == "online"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_returns_token(self, client, db_session, mock_redis):
|
||||
"""验证登录返回 Token 存入 Redis。"""
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "token_test_agent", "name": "Token测试"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert "token" in data["data"]
|
||||
|
||||
# 验证 Redis 中存储了 token(key 格式:agent:token:{token})
|
||||
token = data["data"]["token"]
|
||||
redis_key = f"agent:token:{token}"
|
||||
stored_value = await mock_redis.get(redis_key)
|
||||
assert stored_value is not None
|
||||
|
||||
|
||||
class TestAgentAuthentication:
|
||||
"""测试坐席认证。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_me_with_valid_token(self, client, db_session, mock_redis):
|
||||
"""验证有效 Token 获取坐席信息。"""
|
||||
# 先登录获取 token
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "me_test_agent", "name": "我测试"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
# 用 token 获取坐席信息
|
||||
response = await client.get(
|
||||
"/agents/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["user_id"] == "me_test_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_me_without_token(self, client, db_session, mock_redis):
|
||||
"""验证缺少 Token 返回未授权。"""
|
||||
response = await client.get("/agents/me")
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002 # ERR_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_me_with_invalid_token(self, client, db_session, mock_redis):
|
||||
"""验证无效 Token 返回未授权。"""
|
||||
response = await client.get(
|
||||
"/agents/me",
|
||||
headers={"Authorization": "Bearer invalid_token_12345"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
class TestAgentStatusUpdate:
|
||||
"""测试坐席状态更新。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_to_busy(self, client, db_session, mock_redis):
|
||||
"""验证更新坐席状态为忙碌。"""
|
||||
# 先登录
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "status_test_agent", "name": "状态测试"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
# 更新状态
|
||||
response = await client.put(
|
||||
"/agents/me/status",
|
||||
json={"status": "busy"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["data"]["status"] == "busy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_to_offline(self, client, db_session, mock_redis):
|
||||
"""验证更新坐席状态为离线。"""
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "offline_test_agent", "name": "离线测试"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
response = await client.put(
|
||||
"/agents/me/status",
|
||||
json={"status": "offline"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["data"]["status"] == "offline"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_invalid_value(self, client, db_session, mock_redis):
|
||||
"""验证无效状态值返回校验错误。"""
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "invalid_status_agent", "name": "无效状态测试"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
response = await client.put(
|
||||
"/agents/me/status",
|
||||
json={"status": "invalid_status"},
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
# Pydantic 校验应返回 422
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestAgentList:
|
||||
"""测试坐席列表。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents(self, client, db_session, mock_redis):
|
||||
"""验证获取坐席列表。"""
|
||||
agent1 = create_test_agent(user_id="list_agent_1", name="坐席一")
|
||||
agent2 = create_test_agent(user_id="list_agent_2", name="坐席二")
|
||||
db_session.add_all([agent1, agent2])
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.get("/agents")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert len(data["data"]["items"]) >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_agents_by_status(self, client, db_session, mock_redis):
|
||||
"""验证按状态过滤坐席列表。"""
|
||||
online_agent = create_test_agent(user_id="online_filter_agent", name="在线坐席", status="online")
|
||||
offline_agent = create_test_agent(user_id="offline_filter_agent", name="离线坐席", status="offline")
|
||||
db_session.add_all([online_agent, offline_agent])
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.get("/agents?status=online")
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
for item in data["data"]["items"]:
|
||||
assert item["status"] == "online"
|
||||
@@ -0,0 +1,143 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — API 基础验证测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 健康检查端点
|
||||
# 2. 统一响应格式(code/data/message)
|
||||
# 3. CORS 配置
|
||||
# 4. 404 路由
|
||||
# 5. AppException 全局异常处理
|
||||
# 6. success_response / error_response 工具函数
|
||||
# =============================================================================
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.utils.response import success_response, error_response, AppException
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""测试健康检查端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, client, db_session):
|
||||
"""验证 /health 端点返回正常状态。"""
|
||||
response = await client.get("/health")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert "service" in data
|
||||
|
||||
|
||||
class TestUnifiedResponseFormat:
|
||||
"""测试统一响应格式。"""
|
||||
|
||||
def test_success_response_format(self):
|
||||
"""验证成功响应格式:{code: 0, data: {}, message: "success"}。"""
|
||||
result = success_response(data={"key": "value"})
|
||||
|
||||
assert result["code"] == 0
|
||||
assert result["data"] == {"key": "value"}
|
||||
assert result["message"] == "success"
|
||||
|
||||
def test_success_response_default_data(self):
|
||||
"""验证成功响应默认 data 为 None。"""
|
||||
result = success_response()
|
||||
|
||||
assert result["code"] == 0
|
||||
assert result["data"] is None
|
||||
|
||||
def test_success_response_custom_message(self):
|
||||
"""验证成功响应自定义消息。"""
|
||||
result = success_response(message="操作成功")
|
||||
|
||||
assert result["message"] == "操作成功"
|
||||
|
||||
def test_error_response_format(self):
|
||||
"""验证错误响应格式:{code: N, data: null, message: "错误信息"}。"""
|
||||
result = error_response(1001, "参数错误")
|
||||
|
||||
assert result["code"] == 1001
|
||||
assert result["data"] is None
|
||||
assert result["message"] == "参数错误"
|
||||
|
||||
def test_error_response_with_data(self):
|
||||
"""验证错误响应可附带额外数据。"""
|
||||
result = error_response(1001, "校验失败", data={"field": "email"})
|
||||
|
||||
assert result["data"] == {"field": "email"}
|
||||
|
||||
|
||||
class TestAppException:
|
||||
"""测试业务异常类。"""
|
||||
|
||||
def test_app_exception_attributes(self):
|
||||
"""验证 AppException 包含 code/message/data 属性。"""
|
||||
exc = AppException(1002, "未授权")
|
||||
|
||||
assert exc.code == 1002
|
||||
assert exc.message == "未授权"
|
||||
assert exc.data is None
|
||||
|
||||
def test_app_exception_with_data(self):
|
||||
"""验证 AppException 可附带数据。"""
|
||||
exc = AppException(1001, "参数错误", data={"field": "id"})
|
||||
|
||||
assert exc.data == {"field": "id"}
|
||||
|
||||
def test_app_exception_is_exception(self):
|
||||
"""验证 AppException 是 Exception 的子类。"""
|
||||
exc = AppException(1001, "测试")
|
||||
assert isinstance(exc, Exception)
|
||||
|
||||
def test_predefined_error_constants(self):
|
||||
"""验证预定义错误常量。"""
|
||||
from app.utils.response import (
|
||||
ERR_PARAMS, ERR_UNAUTHORIZED, ERR_NOT_FOUND,
|
||||
ERR_FORBIDDEN, ERR_INTERNAL,
|
||||
)
|
||||
|
||||
assert ERR_PARAMS.code == 1001
|
||||
assert ERR_UNAUTHORIZED.code == 1002
|
||||
assert ERR_NOT_FOUND.code == 1003
|
||||
assert ERR_FORBIDDEN.code == 1004
|
||||
assert ERR_INTERNAL.code == 1005
|
||||
|
||||
|
||||
class TestAPIRoutes:
|
||||
"""测试 API 路由基础。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_404_not_found(self, client, db_session):
|
||||
"""验证访问不存在的路由返回 404。"""
|
||||
response = await client.get("/nonexistent-route")
|
||||
assert response.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_prefix(self, client, db_session):
|
||||
"""验证 API 路径前缀为 /api。"""
|
||||
# /api/agents 是有效路由
|
||||
response = await client.get("/agents")
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversations_list_response_format(self, client, db_session, mock_redis):
|
||||
"""验证会话列表 API 返回统一响应格式。"""
|
||||
# 先登录坐席获取 token(/api/conversations 需要 get_current_agent 认证)
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "basic_test_agent", "name": "基础测试坐席"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
response = await client.get(
|
||||
"/conversations",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert "code" in data
|
||||
assert "data" in data
|
||||
assert "message" in data
|
||||
assert data["code"] == 0
|
||||
@@ -0,0 +1,834 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 摇人多坐席协作功能 测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 一、邀请协作(POST /api/conversations/{id}/invite)
|
||||
# 1. 成功邀请:collaborating_agent_ids 更新,WS广播,WS定向推送
|
||||
# 2. 邀请已结单会话 → 3002
|
||||
# 3. 邀请未接单会话(queued)→ 3020
|
||||
# 4. 非主责/协作坐席邀请 → 3021
|
||||
# 5. 邀请主责坐席本人 → 3022
|
||||
# 6. 邀请已在协作中的坐席 → 3023
|
||||
# 7. 邀请离线坐席 → 3024
|
||||
# 8. 协作坐席也可以摇人(再摇第三人)
|
||||
# 9. 邀请不存在的坐席 → 3004
|
||||
# 10. 邀请不存在的会话 → 3003
|
||||
#
|
||||
# 二、退出协作(POST /api/conversations/{id}/leave)
|
||||
# 1. 成功退出:从 collaborating_agent_ids 移除,WS 广播
|
||||
# 2. 主责坐席尝试退出 → 3025
|
||||
# 3. 非协作坐席退出 → 3026
|
||||
# 4. 退出后清空当前选中会话
|
||||
#
|
||||
# 三、列表集成测试
|
||||
# 1. collaborating_agent_ids 字段正确
|
||||
# 2. collaborating_agent_names 姓名映射正确
|
||||
# 3. is_collaborator 字段正确
|
||||
# 4. 协作坐席仍能查看和回复
|
||||
#
|
||||
# 四、权限矩阵验证(端到端)
|
||||
# 1. 协作坐席不能结单
|
||||
# 2. 协作坐席不能转接
|
||||
# 3. 协作坐席不占负载
|
||||
# =============================================================================
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.services.session_service import SessionService
|
||||
from app.utils.response import AppException
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
|
||||
async def login_agent(client: AsyncClient, user_id: str, name: str) -> dict:
|
||||
"""登录坐席并返回认证头字典。"""
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": user_id, "name": name},
|
||||
)
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
async def create_serving_conversation(
|
||||
db_session: AsyncSession,
|
||||
employee_id: str = "emp_001",
|
||||
agent_user_id: str = "agent_owner",
|
||||
collab_ids: list = None,
|
||||
) -> Conversation:
|
||||
"""创建一个 serving 状态且有主责坐席的会话(可选已有协作坐席)。"""
|
||||
conv = create_test_conversation(
|
||||
employee_id=employee_id,
|
||||
status="serving",
|
||||
)
|
||||
conv.assigned_agent_id = agent_user_id
|
||||
conv.collaborating_agent_ids = collab_ids or []
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
return conv
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 一、邀请协作测试
|
||||
# =============================================================================
|
||||
|
||||
class TestInviteCollaborator:
|
||||
"""测试邀请协作接口 POST /api/conversations/{id}/invite。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_success_updates_collaborating_ids(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功邀请:collaborating_agent_ids 更新,WS广播+定向推送。
|
||||
|
||||
场景:坐席A(owner)在处理会话,邀请在线坐席B加入协作。
|
||||
"""
|
||||
# 创建坐席
|
||||
owner = create_test_agent(user_id="owner_001", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_001", name="坐席B", status="online")
|
||||
db_session.add_all([owner, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建 serving 会话,分配给 owner
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_invite", agent_user_id="owner_001"
|
||||
)
|
||||
|
||||
# 坐席A 登录并发起邀请
|
||||
headers = await login_agent(client, "owner_001", "坐席A")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock) as mock_broadcast, \
|
||||
patch("app.services.ws_manager.manager.send_to_agent", new_callable=AsyncMock) as mock_send:
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "invitee_001"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
|
||||
# 验证 collaborating_agent_ids 包含被邀请坐席
|
||||
result = data["data"]
|
||||
assert "invitee_001" in result["collaborating_agent_ids"]
|
||||
|
||||
# 验证 WS 广播被调用(collaborator_joined)
|
||||
mock_broadcast.assert_called_once()
|
||||
broadcast_msg = mock_broadcast.call_args[0][0]
|
||||
assert broadcast_msg["type"] == "collaborator_joined"
|
||||
assert broadcast_msg["data"]["agent_id"] == "invitee_001"
|
||||
assert broadcast_msg["data"]["inviter_agent_id"] == "owner_001"
|
||||
|
||||
# 验证 WS 定向推送被调用(collaborator_invited)
|
||||
mock_send.assert_called_once_with("invitee_001", mock_send.call_args[0][1])
|
||||
sent_msg = mock_send.call_args[0][1]
|
||||
assert sent_msg["type"] == "collaborator_invited"
|
||||
assert sent_msg["data"]["invitee_agent_id"] == "invitee_001"
|
||||
|
||||
# 验证数据库持久化
|
||||
stmt = select(Conversation).where(Conversation.id == conv.id)
|
||||
result_db = await db_session.execute(stmt)
|
||||
db_conv = result_db.scalars().first()
|
||||
assert "invitee_001" in db_conv.collaborating_agent_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_resolved_conversation_error_3002(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能邀请已结单会话 → 3002。"""
|
||||
owner = create_test_agent(user_id="owner_resolved", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_resolved", name="坐席B", status="online")
|
||||
db_session.add_all([owner, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_resolved_invite", status="resolved"
|
||||
)
|
||||
conv.assigned_agent_id = "owner_resolved"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "owner_resolved", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "invitee_resolved"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3002 # ERR_CONVERSATION_RESOLVED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_queued_conversation_error_3020(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能邀请未接单(queued)的会话 → 3020。"""
|
||||
owner = create_test_agent(user_id="owner_queued", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_queued", name="坐席B", status="online")
|
||||
db_session.add_all([owner, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
# queued 状态,未分配坐席 → 应先报 3021(不是 owner/collaborator)
|
||||
# 但如果有 assigned_agent_id=owner,则是 queued 状态 → 3020
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_queued_invite", status="queued"
|
||||
)
|
||||
conv.assigned_agent_id = "owner_queued"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "owner_queued", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "invitee_queued"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3020
|
||||
assert "服务中" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_by_non_owner_error_3021(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证非主责/协作坐席不能摇人 → 3021。"""
|
||||
owner = create_test_agent(user_id="owner_3021", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_3021", name="坐席B", status="online")
|
||||
stranger = create_test_agent(user_id="stranger_3021", name="路过的坐席", status="online")
|
||||
db_session.add_all([owner, invitee, stranger])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_3021", agent_user_id="owner_3021"
|
||||
)
|
||||
|
||||
# 用路过的坐席登录(既不是主责也不是协作坐席)
|
||||
headers = await login_agent(client, "stranger_3021", "路过的坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "invitee_3021"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3021
|
||||
assert "摇人" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_owner_self_error_3022(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能邀请主责坐席本人 → 3022。"""
|
||||
owner = create_test_agent(user_id="owner_3022", name="坐席A", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_3022", agent_user_id="owner_3022"
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_3022", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "owner_3022"}, # 邀请自己
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3022
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_duplicate_collaborator_error_3023(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能重复邀请已在协作中的坐席 → 3023。"""
|
||||
owner = create_test_agent(user_id="owner_3023", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_3023", name="坐席B", status="online")
|
||||
db_session.add_all([owner, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
# 坐席B 已在协作列表中
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_3023",
|
||||
agent_user_id="owner_3023",
|
||||
collab_ids=["invitee_3023"],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_3023", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "invitee_3023"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3023
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_offline_agent_error_3024(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能邀请离线坐席 → 3024。"""
|
||||
owner = create_test_agent(user_id="owner_3024", name="坐席A", status="online")
|
||||
offline_agent = create_test_agent(user_id="offline_3024", name="离线坐席", status="offline")
|
||||
db_session.add_all([owner, offline_agent])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_3024", agent_user_id="owner_3024"
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_3024", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "offline_3024"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3024
|
||||
assert "不在线" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_nonexistent_agent_error_3004(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证邀请不存在的坐席 → 3004。"""
|
||||
owner = create_test_agent(user_id="owner_3004", name="坐席A", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_3004", agent_user_id="owner_3004"
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_3004", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "nonexistent_agent"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3004 # ERR_AGENT_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_nonexistent_conversation_error_3003(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证邀请不存在的会话 → 3003。"""
|
||||
agent = create_test_agent(user_id="agent_3003", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="invitee_3003", name="坐席B", status="online")
|
||||
db_session.add_all([agent, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
fake_id = str(uuid.uuid4())
|
||||
headers = await login_agent(client, "agent_3003", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{fake_id}/invite",
|
||||
json={"agent_id": "invitee_3003"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3003 # ERR_CONVERSATION_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collaborator_can_also_invite_others(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证协作坐席也可以摇人(再摇第三人加入)。
|
||||
|
||||
场景:坐席A(owner)邀请坐席B → 坐席B 再摇坐席C
|
||||
"""
|
||||
owner = create_test_agent(user_id="chain_owner", name="坐席A", status="online")
|
||||
collab1 = create_test_agent(user_id="chain_collab1", name="坐席B", status="online")
|
||||
collab2 = create_test_agent(user_id="chain_collab2", name="坐席C", status="online")
|
||||
db_session.add_all([owner, collab1, collab2])
|
||||
await db_session.flush()
|
||||
|
||||
# 坐席A 创建会话,邀请坐席B
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_chain",
|
||||
agent_user_id="chain_owner",
|
||||
collab_ids=["chain_collab1"],
|
||||
)
|
||||
|
||||
# 坐席B 登录并发起邀请坐席C
|
||||
headers = await login_agent(client, "chain_collab1", "坐席B")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock), \
|
||||
patch("app.services.ws_manager.manager.send_to_agent", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "chain_collab2"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
|
||||
# 验证 collaborating_agent_ids 包含两个协作坐席
|
||||
result = data["data"]
|
||||
assert "chain_collab1" in result["collaborating_agent_ids"]
|
||||
assert "chain_collab2" in result["collaborating_agent_ids"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_without_auth_returns_unauthorized(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未登录时邀请返回未授权错误。"""
|
||||
owner = create_test_agent(user_id="noauth_owner", name="坐席A", status="online")
|
||||
invitee = create_test_agent(user_id="noauth_invitee", name="坐席B", status="online")
|
||||
db_session.add_all([owner, invitee])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_noauth", agent_user_id="noauth_owner"
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "noauth_invitee"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002 # ERR_UNAUTHORIZED
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 二、退出协作测试
|
||||
# =============================================================================
|
||||
|
||||
class TestLeaveCollaboration:
|
||||
"""测试退出协作接口 POST /api/conversations/{id}/leave。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_success_removes_from_list(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功退出:从 collaborating_agent_ids 移除,WS 广播。"""
|
||||
owner = create_test_agent(user_id="leave_owner", name="坐席A", status="online")
|
||||
collab = create_test_agent(user_id="leave_collab", name="坐席B", status="online")
|
||||
db_session.add_all([owner, collab])
|
||||
await db_session.flush()
|
||||
|
||||
# 坐席B 已在协作列表中
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_leave",
|
||||
agent_user_id="leave_owner",
|
||||
collab_ids=["leave_collab"],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "leave_collab", "坐席B")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock) as mock_broadcast:
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
|
||||
# 验证 collaborating_agent_ids 不再包含坐席B
|
||||
result = data["data"]
|
||||
assert "leave_collab" not in result["collaborating_agent_ids"]
|
||||
|
||||
# 验证 WS 广播被调用(collaborator_left)
|
||||
mock_broadcast.assert_called_once()
|
||||
broadcast_msg = mock_broadcast.call_args[0][0]
|
||||
assert broadcast_msg["type"] == "collaborator_left"
|
||||
assert broadcast_msg["data"]["agent_id"] == "leave_collab"
|
||||
|
||||
# 验证数据库持久化
|
||||
stmt = select(Conversation).where(Conversation.id == conv.id)
|
||||
result_db = await db_session.execute(stmt)
|
||||
db_conv = result_db.scalars().first()
|
||||
assert "leave_collab" not in db_conv.collaborating_agent_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_owner_error_3025(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证主责坐席不能退出协作 → 3025。"""
|
||||
owner = create_test_agent(user_id="leave_owner_3025", name="坐席A", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_leave_owner",
|
||||
agent_user_id="leave_owner_3025",
|
||||
collab_ids=["some_collab"],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "leave_owner_3025", "坐席A")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3025
|
||||
assert "主责坐席" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_non_collaborator_error_3026(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不在协作列表中的坐席不能退出 → 3026。"""
|
||||
owner = create_test_agent(user_id="leave_owner_3026", name="坐席A", status="online")
|
||||
stranger = create_test_agent(user_id="stranger_3026", name="路过的坐席", status="online")
|
||||
db_session.add_all([owner, stranger])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_leave_stranger",
|
||||
agent_user_id="leave_owner_3026",
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "stranger_3026", "路过的坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3026
|
||||
assert "协作列表" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_without_auth_returns_unauthorized(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未登录时退出返回未授权错误。"""
|
||||
owner = create_test_agent(user_id="noauth_leave_owner", name="坐席A", status="online")
|
||||
collab = create_test_agent(user_id="noauth_leave_collab", name="坐席B", status="online")
|
||||
db_session.add_all([owner, collab])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_noauth_leave",
|
||||
agent_user_id="noauth_leave_owner",
|
||||
collab_ids=["noauth_leave_collab"],
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002 # ERR_UNAUTHORIZED
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 三、列表集成测试
|
||||
# =============================================================================
|
||||
|
||||
class TestCollaborationListIntegration:
|
||||
"""测试会话列表接口的协作字段集成。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_includes_collaboration_fields(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证列表接口返回 collaborating_agent_ids 和 _names 字段。"""
|
||||
owner = create_test_agent(user_id="list_owner", name="坐席A", status="online")
|
||||
collab1 = create_test_agent(user_id="list_collab1", name="坐席B", status="online")
|
||||
collab2 = create_test_agent(user_id="list_collab2", name="坐席C", status="online")
|
||||
db_session.add_all([owner, collab1, collab2])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_list_collab",
|
||||
agent_user_id="list_owner",
|
||||
collab_ids=["list_collab1", "list_collab2"],
|
||||
)
|
||||
|
||||
# 以坐席A身份查看列表
|
||||
headers = await login_agent(client, "list_owner", "坐席A")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
items = data["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
conv_item = item_map[str(conv.id)]
|
||||
|
||||
# 验证 collaborating_agent_ids
|
||||
assert "list_collab1" in conv_item["collaborating_agent_ids"]
|
||||
assert "list_collab2" in conv_item["collaborating_agent_ids"]
|
||||
assert len(conv_item["collaborating_agent_ids"]) == 2
|
||||
|
||||
# 验证 collaborating_agent_names
|
||||
assert conv_item["collaborating_agent_names"]["list_collab1"] == "坐席B"
|
||||
assert conv_item["collaborating_agent_names"]["list_collab2"] == "坐席C"
|
||||
|
||||
# 验证 is_collaborator(坐席A是主责不是协作坐席)
|
||||
assert conv_item["is_collaborator"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_is_collaborator_field_correctness(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证 is_collaborator 字段标注正确。
|
||||
|
||||
- 主责坐席 → is_collaborator=False
|
||||
- 协作坐席(且非主责)→ is_collaborator=True
|
||||
- 既非主责也非协作 → is_collaborator=False
|
||||
"""
|
||||
owner = create_test_agent(user_id="iscoll_owner", name="主责坐席", status="online")
|
||||
collab = create_test_agent(user_id="iscoll_collab", name="协作坐席", status="online")
|
||||
stranger = create_test_agent(user_id="iscoll_stranger", name="路人坐席", status="online")
|
||||
db_session.add_all([owner, collab, stranger])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_iscoll",
|
||||
agent_user_id="iscoll_owner",
|
||||
collab_ids=["iscoll_collab"],
|
||||
)
|
||||
|
||||
# 主责坐席查看 → is_collaborator=False
|
||||
headers_owner = await login_agent(client, "iscoll_owner", "主责坐席")
|
||||
resp = await client.get("/conversations", headers=headers_owner)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
assert item_map[str(conv.id)]["is_collaborator"] is False
|
||||
|
||||
# 协作坐席查看 → is_collaborator=True
|
||||
headers_collab = await login_agent(client, "iscoll_collab", "协作坐席")
|
||||
resp = await client.get("/conversations", headers=headers_collab)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
assert item_map[str(conv.id)]["is_collaborator"] is True
|
||||
|
||||
# 路人坐席查看 → is_collaborator=False
|
||||
headers_stranger = await login_agent(client, "iscoll_stranger", "路人坐席")
|
||||
resp = await client.get("/conversations", headers=headers_stranger)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
assert item_map[str(conv.id)]["is_collaborator"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_no_collaborators_returns_empty_arrays(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证无协作坐席时返回空数组和空对象。"""
|
||||
owner = create_test_agent(user_id="empty_owner", name="坐席A", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_empty_collab", agent_user_id="empty_owner"
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "empty_owner", "坐席A")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
|
||||
data = response.json()
|
||||
items = data["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
conv_item = item_map[str(conv.id)]
|
||||
assert conv_item["collaborating_agent_ids"] == []
|
||||
assert conv_item["collaborating_agent_names"] == {}
|
||||
assert conv_item["is_collaborator"] is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 四、权限矩阵验证(端到端)
|
||||
# =============================================================================
|
||||
|
||||
class TestCollaborationPermissions:
|
||||
"""测试协作坐席的权限边界。
|
||||
|
||||
协作坐席可以:查看会话、发送回复、摇人(再邀请)
|
||||
协作坐席不能:结单、转接、置顶/代办
|
||||
协作坐席不占负载。
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collaborator_does_not_count_load(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证协作坐席加入后负载不变(不占负载)。
|
||||
|
||||
主责坐席 current_load 应保持为1,协作坐席 current_load 保持不变。
|
||||
"""
|
||||
owner = create_test_agent(user_id="load_owner", name="坐席A", status="online")
|
||||
owner.current_load = 1
|
||||
|
||||
collab = create_test_agent(user_id="load_collab", name="坐席B", status="online")
|
||||
collab.current_load = 0
|
||||
|
||||
db_session.add_all([owner, collab])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_load", agent_user_id="load_owner"
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "load_owner", "坐席A")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock), \
|
||||
patch("app.services.ws_manager.manager.send_to_agent", new_callable=AsyncMock):
|
||||
await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "load_collab"},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 验证主责坐席 load 不变(=1)
|
||||
stmt = select(Agent).where(Agent.user_id == "load_owner")
|
||||
result = await db_session.execute(stmt)
|
||||
db_owner = result.scalars().first()
|
||||
assert db_owner.current_load == 1
|
||||
|
||||
# 验证协作坐席 load 不变(=0)
|
||||
stmt = select(Agent).where(Agent.user_id == "load_collab")
|
||||
result = await db_session.execute(stmt)
|
||||
db_collab = result.scalars().first()
|
||||
assert db_collab.current_load == 0 # 协作不占负载
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collaborator_cannot_resolve_conversation(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证协作坐席不能结单。
|
||||
|
||||
场景:坐席A(owner)邀请坐席B协作 → 坐席B 尝试结单(应失败)
|
||||
"""
|
||||
owner = create_test_agent(user_id="perm_owner", name="坐席A", status="online")
|
||||
collab = create_test_agent(user_id="perm_collab", name="坐席B", status="online")
|
||||
db_session.add_all([owner, collab])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation(
|
||||
db_session,
|
||||
employee_id="emp_perm",
|
||||
agent_user_id="perm_owner",
|
||||
collab_ids=["perm_collab"],
|
||||
)
|
||||
|
||||
# 协作坐席登录并尝试结单
|
||||
headers = await login_agent(client, "perm_collab", "坐席B")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/resolve",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 协作坐席不是主责坐席,resolve 应返回 3027(只有主责坐席才能结单)
|
||||
assert data["code"] == 3027, f"协作坐席不应该能结单,期望 code=3027,实际 code={data['code']}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_invite_leave_cycle(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""端到端测试:邀请→查看列表→退出→验证清理。
|
||||
|
||||
完整流程:
|
||||
1. 坐席A 邀请坐席B
|
||||
2. 坐席B 查看列表,确认协作会话出现
|
||||
3. 坐席A 邀请坐席C
|
||||
4. 坐席A 查看列表,验证两个协作坐席
|
||||
5. 坐席B 退出协作
|
||||
6. 验证坐席B 不再出现在协作列表中,坐席C 仍存在
|
||||
"""
|
||||
# 创建坐席
|
||||
owner = create_test_agent(user_id="e2e_owner", name="坐席A", status="online")
|
||||
collab_b = create_test_agent(user_id="e2e_b", name="坐席B", status="online")
|
||||
collab_c = create_test_agent(user_id="e2e_c", name="坐席C", status="online")
|
||||
db_session.add_all([owner, collab_b, collab_c])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建会话
|
||||
conv = await create_serving_conversation(
|
||||
db_session, employee_id="emp_e2e", agent_user_id="e2e_owner"
|
||||
)
|
||||
|
||||
headers_a = await login_agent(client, "e2e_owner", "坐席A")
|
||||
|
||||
# Step 1: 坐席A 邀请坐席B
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock), \
|
||||
patch("app.services.ws_manager.manager.send_to_agent", new_callable=AsyncMock):
|
||||
resp = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "e2e_b"},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.json()["code"] == 0
|
||||
|
||||
# Step 2: 坐席B 查看列表,确认协作会话出现
|
||||
headers_b = await login_agent(client, "e2e_b", "坐席B")
|
||||
resp = await client.get("/conversations", headers=headers_b)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
assert str(conv.id) in item_map
|
||||
assert item_map[str(conv.id)]["is_collaborator"] is True
|
||||
|
||||
# Step 3: 坐席A 邀请坐席C
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock), \
|
||||
patch("app.services.ws_manager.manager.send_to_agent", new_callable=AsyncMock):
|
||||
resp = await client.post(
|
||||
f"/conversations/{conv.id}/invite",
|
||||
json={"agent_id": "e2e_c"},
|
||||
headers=headers_a,
|
||||
)
|
||||
assert resp.json()["code"] == 0
|
||||
|
||||
# Step 4: 坐席A 查看列表,验证两个协作坐席
|
||||
resp = await client.get("/conversations", headers=headers_a)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
conv_item = item_map[str(conv.id)]
|
||||
assert "e2e_b" in conv_item["collaborating_agent_ids"]
|
||||
assert "e2e_c" in conv_item["collaborating_agent_ids"]
|
||||
assert conv_item["collaborating_agent_names"]["e2e_b"] == "坐席B"
|
||||
assert conv_item["collaborating_agent_names"]["e2e_c"] == "坐席C"
|
||||
|
||||
# Step 5: 坐席B 退出协作
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
resp = await client.post(
|
||||
f"/conversations/{conv.id}/leave",
|
||||
headers=headers_b,
|
||||
)
|
||||
assert resp.json()["code"] == 0
|
||||
|
||||
# Step 6: 验证坐席B 已移除,坐席C 仍存在
|
||||
resp = await client.get("/conversations", headers=headers_a)
|
||||
items = resp.json()["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
conv_item = item_map[str(conv.id)]
|
||||
assert "e2e_b" not in conv_item["collaborating_agent_ids"]
|
||||
assert "e2e_c" in conv_item["collaborating_agent_ids"]
|
||||
@@ -0,0 +1,749 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 坐席会话全局可见 + 接手功能 测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 一、会话列表接口(GET /api/conversations)
|
||||
# 1. 返回全部活跃会话(queued + serving + 其他坐席 serving)
|
||||
# 2. is_mine / assigned_agent_name / can_grab 字段标注正确
|
||||
# 3. N+1 查询优化(坐席信息批量查询)
|
||||
#
|
||||
# 二、接手接口(POST /api/conversations/{id}/grab)
|
||||
# 1. 成功接手:原坐席 load-1,新坐席 load+1,assigned_agent_id 切换
|
||||
# 2. 不能接手未分配坐席的会话 → 3011
|
||||
# 3. 不能接手自己的会话 → 3012
|
||||
# 4. 不能接手非 serving 状态会话 → 3013
|
||||
# 5. 不能接手已结单会话 → 3002
|
||||
# 6. 接手后 WebSocket 广播 conversation_updated
|
||||
#
|
||||
# 三、边界情况
|
||||
# 1. 满负荷坐席接手 → 3005
|
||||
# 2. 会话不存在 → 3003
|
||||
# 3. 接手成功后返回字段验证(is_mine=True, can_grab=False)
|
||||
# =============================================================================
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.services.session_service import SessionService
|
||||
from app.utils.response import AppException
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数:登录坐席并获取认证头
|
||||
# =============================================================================
|
||||
|
||||
async def login_agent(client: AsyncClient, user_id: str, name: str) -> dict:
|
||||
"""登录坐席并返回认证头字典。"""
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": user_id, "name": name},
|
||||
)
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
async def create_and_assign_conversation(
|
||||
db_session: AsyncSession,
|
||||
employee_id: str = "emp_001",
|
||||
agent_user_id: str = "agent_001",
|
||||
) -> Conversation:
|
||||
"""创建一个已分配坐席的 serving 状态会话。"""
|
||||
conv = create_test_conversation(
|
||||
employee_id=employee_id,
|
||||
status="serving",
|
||||
)
|
||||
conv.assigned_agent_id = agent_user_id
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
return conv
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 一、会话列表接口测试
|
||||
# =============================================================================
|
||||
|
||||
class TestConversationListGlobalVisibility:
|
||||
"""测试会话列表全局可见功能。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_returns_all_active_conversations(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证 GET /api/conversations 返回全部活跃会话。
|
||||
|
||||
场景:数据库中有 queued、serving(自己的)、serving(其他坐席的)三种会话,
|
||||
当前坐席应能看到所有这些会话。
|
||||
"""
|
||||
# 准备:创建坐席
|
||||
agent = create_test_agent(user_id="viewer_001", name="查看坐席")
|
||||
other_agent = create_test_agent(user_id="other_001", name="其他坐席")
|
||||
db_session.add_all([agent, other_agent])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建三种状态的会话
|
||||
conv_queued = create_test_conversation(
|
||||
employee_id="emp_queued", status="queued"
|
||||
)
|
||||
conv_my_serving = create_test_conversation(
|
||||
employee_id="emp_my", status="serving"
|
||||
)
|
||||
conv_my_serving.assigned_agent_id = "viewer_001"
|
||||
|
||||
conv_other_serving = create_test_conversation(
|
||||
employee_id="emp_other", status="serving"
|
||||
)
|
||||
conv_other_serving.assigned_agent_id = "other_001"
|
||||
|
||||
db_session.add_all([conv_queued, conv_my_serving, conv_other_serving])
|
||||
await db_session.flush()
|
||||
|
||||
# 登录并请求
|
||||
headers = await login_agent(client, "viewer_001", "查看坐席")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
items = data["data"]["items"]
|
||||
total = data["data"]["total"]
|
||||
|
||||
# 应至少包含我们创建的3个活跃会话
|
||||
assert total >= 3
|
||||
# 验证三种类型的会话都出现在结果中
|
||||
conv_ids = {item["id"] for item in items}
|
||||
assert str(conv_queued.id) in conv_ids
|
||||
assert str(conv_my_serving.id) in conv_ids
|
||||
assert str(conv_other_serving.id) in conv_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_mine_field_correctness(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证 is_mine 字段标注正确。
|
||||
|
||||
- 自己的会话 is_mine=True
|
||||
- 其他坐席的会话 is_mine=False
|
||||
- 未分配坐席的会话 is_mine=False
|
||||
"""
|
||||
agent = create_test_agent(user_id="mine_agent", name="我的坐席")
|
||||
other = create_test_agent(user_id="other_agent", name="他人坐席")
|
||||
db_session.add_all([agent, other])
|
||||
await db_session.flush()
|
||||
|
||||
conv_mine = create_test_conversation(
|
||||
employee_id="emp_mine", status="serving"
|
||||
)
|
||||
conv_mine.assigned_agent_id = "mine_agent"
|
||||
|
||||
conv_other = create_test_conversation(
|
||||
employee_id="emp_other2", status="serving"
|
||||
)
|
||||
conv_other.assigned_agent_id = "other_agent"
|
||||
|
||||
conv_unassigned = create_test_conversation(
|
||||
employee_id="emp_unassigned", status="queued"
|
||||
)
|
||||
|
||||
db_session.add_all([conv_mine, conv_other, conv_unassigned])
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "mine_agent", "我的坐席")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
data = response.json()
|
||||
items = data["data"]["items"]
|
||||
|
||||
# 构建一个 id → item 的映射
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
# 自己的会话 → is_mine=True
|
||||
assert item_map[str(conv_mine.id)]["is_mine"] is True
|
||||
# 其他坐席的会话 → is_mine=False
|
||||
assert item_map[str(conv_other.id)]["is_mine"] is False
|
||||
# 未分配坐席的会话 → is_mine=False
|
||||
assert item_map[str(conv_unassigned.id)]["is_mine"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assigned_agent_name_field(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证 assigned_agent_name 字段正确返回坐席姓名。
|
||||
|
||||
- 已分配坐席的会话应返回坐席姓名
|
||||
- 未分配坐席的会话应返回 None
|
||||
"""
|
||||
agent = create_test_agent(user_id="name_agent", name="坐席张三")
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
conv_assigned = create_test_conversation(
|
||||
employee_id="emp_assigned", status="serving"
|
||||
)
|
||||
conv_assigned.assigned_agent_id = "name_agent"
|
||||
|
||||
conv_unassigned = create_test_conversation(
|
||||
employee_id="emp_no_agent", status="queued"
|
||||
)
|
||||
|
||||
db_session.add_all([conv_assigned, conv_unassigned])
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "name_agent", "坐席张三")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
data = response.json()
|
||||
items = data["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
# 已分配的会话应包含坐席姓名
|
||||
assert item_map[str(conv_assigned.id)]["assigned_agent_name"] == "坐席张三"
|
||||
# 未分配的会话坐席姓名为 None
|
||||
assert item_map[str(conv_unassigned.id)]["assigned_agent_name"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_grab_field_correctness(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证 can_grab 字段标注正确。
|
||||
|
||||
can_grab = True 的条件:assigned_agent_id 非空 且 不是自己 且 status=serving
|
||||
- 其他坐席的 serving 会话 → can_grab=True
|
||||
- 自己的会话 → can_grab=False
|
||||
- 未分配的 queued 会话 → can_grab=False
|
||||
- 其他坐席的 queued 会话 → can_grab=False
|
||||
- 其他坐席的 resolved 会话 → can_grab=False
|
||||
"""
|
||||
agent = create_test_agent(user_id="grab_checker", name="检查坐席")
|
||||
other = create_test_agent(user_id="grab_other", name="他人坐席")
|
||||
db_session.add_all([agent, other])
|
||||
await db_session.flush()
|
||||
|
||||
# 其他坐席的 serving 会话 → 可接手
|
||||
conv_other_serving = create_test_conversation(
|
||||
employee_id="emp_other_serving", status="serving"
|
||||
)
|
||||
conv_other_serving.assigned_agent_id = "grab_other"
|
||||
|
||||
# 自己的会话 → 不可接手
|
||||
conv_my_serving = create_test_conversation(
|
||||
employee_id="emp_my_serving", status="serving"
|
||||
)
|
||||
conv_my_serving.assigned_agent_id = "grab_checker"
|
||||
|
||||
# 未分配的 queued 会话 → 不可接手
|
||||
conv_queued = create_test_conversation(
|
||||
employee_id="emp_queued_grab", status="queued"
|
||||
)
|
||||
|
||||
# 已结单会话 → 不可接手
|
||||
conv_resolved = create_test_conversation(
|
||||
employee_id="emp_resolved_grab", status="resolved"
|
||||
)
|
||||
conv_resolved.assigned_agent_id = "grab_other"
|
||||
|
||||
db_session.add_all([
|
||||
conv_other_serving, conv_my_serving, conv_queued, conv_resolved
|
||||
])
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "grab_checker", "检查坐席")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
data = response.json()
|
||||
items = data["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
# 其他坐席 serving → can_grab=True
|
||||
assert item_map[str(conv_other_serving.id)]["can_grab"] is True
|
||||
# 自己的会话 → can_grab=False
|
||||
assert item_map[str(conv_my_serving.id)]["can_grab"] is False
|
||||
# queued 会话 → can_grab=False
|
||||
assert item_map[str(conv_queued.id)]["can_grab"] is False
|
||||
# resolved 会话 → can_grab=False
|
||||
assert item_map[str(conv_resolved.id)]["can_grab"] is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 二、接手接口测试
|
||||
# =============================================================================
|
||||
|
||||
class TestGrabConversation:
|
||||
"""测试接手会话接口 POST /api/conversations/{id}/grab。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_success_switches_agent_and_load(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功接手:原坐席 load-1,新坐席 load+1,assigned_agent_id 切换。"""
|
||||
# 创建原坐席(已有1个会话)
|
||||
old_agent = create_test_agent(
|
||||
user_id="old_agent", name="原坐席", status="online"
|
||||
)
|
||||
old_agent.current_load = 1
|
||||
old_agent.max_load = 5
|
||||
|
||||
# 创建新坐席(准备接手)
|
||||
new_agent = create_test_agent(
|
||||
user_id="new_agent", name="新坐席", status="online"
|
||||
)
|
||||
new_agent.current_load = 0
|
||||
new_agent.max_load = 5
|
||||
|
||||
db_session.add_all([old_agent, new_agent])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建一个 serving 状态的会话,分配给原坐席
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_grab_success", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "old_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 新坐席登录并发起接手
|
||||
headers = await login_agent(client, "new_agent", "新坐席")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock) as mock_broadcast:
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
|
||||
# 验证会话的 assigned_agent_id 已切换
|
||||
assert data["data"]["assigned_agent_id"] == "new_agent"
|
||||
|
||||
# 验证原坐席 current_load 减 1
|
||||
stmt = select(Agent).where(Agent.user_id == "old_agent")
|
||||
result = await db_session.execute(stmt)
|
||||
refreshed_old = result.scalars().first()
|
||||
assert refreshed_old.current_load == 0 # 1 - 1 = 0
|
||||
|
||||
# 验证新坐席 current_load 加 1
|
||||
stmt = select(Agent).where(Agent.user_id == "new_agent")
|
||||
result = await db_session.execute(stmt)
|
||||
refreshed_new = result.scalars().first()
|
||||
assert refreshed_new.current_load == 1 # 0 + 1 = 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_no_agent_error_3011(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能接手未分配坐席的会话 → 3011。"""
|
||||
agent = create_test_agent(user_id="grab_no_agent_user", name="测试坐席")
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建一个 queued 状态(未分配坐席)的会话
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_no_agent", status="queued"
|
||||
)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "grab_no_agent_user", "测试坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3011
|
||||
assert "尚未分配坐席" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_self_error_3012(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能接手自己的会话 → 3012。"""
|
||||
agent = create_test_agent(user_id="grab_self_user", name="自接坐席")
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建一个分配给自己的 serving 会话
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_self_grab", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "grab_self_user"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "grab_self_user", "自接坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3012
|
||||
assert "不能接手自己的会话" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_not_serving_error_3013(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能接手非 serving 状态的会话 → 3013。"""
|
||||
other_agent = create_test_agent(user_id="other_for_3013", name="他人坐席")
|
||||
grabber = create_test_agent(user_id="grabber_for_3013", name="接手坐席")
|
||||
db_session.add_all([other_agent, grabber])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建一个 queued 状态但已分配坐席的会话(边界:assigned + queued)
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_not_serving", status="queued"
|
||||
)
|
||||
conv.assigned_agent_id = "other_for_3013"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "grabber_for_3013", "接手坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3013
|
||||
assert "只能接手服务中的会话" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_resolved_error_3002(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能接手已结单的会话 → 3002。
|
||||
|
||||
注意:源码中 resolved 检查在 status != serving 检查之前,
|
||||
所以 resolved 会优先命中 3002 而非 3013。
|
||||
"""
|
||||
other_agent = create_test_agent(user_id="other_for_3002", name="他人坐席")
|
||||
grabber = create_test_agent(user_id="grabber_for_3002", name="接手坐席")
|
||||
db_session.add_all([other_agent, grabber])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建已结单但分配了坐席的会话
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_resolved_grab_test", status="resolved"
|
||||
)
|
||||
conv.assigned_agent_id = "other_for_3002"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "grabber_for_3002", "接手坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# resolved 检查在 status != serving 之前,应返回 3002
|
||||
assert data["code"] == 3002
|
||||
assert "已结单" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_broadcasts_websocket(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证接手成功后广播 WebSocket conversation_updated 事件。"""
|
||||
old_agent = create_test_agent(
|
||||
user_id="ws_old_agent", name="原坐席", status="online"
|
||||
)
|
||||
old_agent.current_load = 1
|
||||
new_agent = create_test_agent(
|
||||
user_id="ws_new_agent", name="新坐席", status="online"
|
||||
)
|
||||
db_session.add_all([old_agent, new_agent])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_ws_grab", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "ws_old_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "ws_new_agent", "新坐席")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock) as mock_broadcast:
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 验证广播被调用
|
||||
mock_broadcast.assert_called_once()
|
||||
broadcast_data = mock_broadcast.call_args[0][0]
|
||||
assert broadcast_data["type"] == "conversation_updated"
|
||||
assert broadcast_data["data"]["conversation_id"] == str(conv.id)
|
||||
assert broadcast_data["data"]["old_agent_id"] == "ws_old_agent"
|
||||
assert broadcast_data["data"]["new_agent_id"] == "ws_new_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_success_response_fields(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证接手成功后返回的扩展字段:is_mine=True, can_grab=False, assigned_agent_name。"""
|
||||
old_agent = create_test_agent(
|
||||
user_id="resp_old_agent", name="原坐席", status="online"
|
||||
)
|
||||
old_agent.current_load = 1
|
||||
new_agent = create_test_agent(
|
||||
user_id="resp_new_agent", name="新坐席", status="online"
|
||||
)
|
||||
db_session.add_all([old_agent, new_agent])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_resp_grab", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "resp_old_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "resp_new_agent", "新坐席")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
result = data["data"]
|
||||
# 接手后该会话属于当前坐席
|
||||
assert result["is_mine"] is True
|
||||
# 自己的会话不能再接手
|
||||
assert result["can_grab"] is False
|
||||
# 坐席姓名应为新坐席姓名
|
||||
assert result["assigned_agent_name"] == "新坐席"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 三、边界情况测试
|
||||
# =============================================================================
|
||||
|
||||
class TestGrabEdgeCases:
|
||||
"""测试接手功能的边界情况。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_when_agent_at_max_load_error_3005(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证满负荷坐席无法接手 → 3005。"""
|
||||
# 原坐席有1个会话
|
||||
old_agent = create_test_agent(
|
||||
user_id="max_old_agent", name="原坐席", status="online"
|
||||
)
|
||||
old_agent.current_load = 1
|
||||
|
||||
# 新坐席已满负荷
|
||||
full_agent = create_test_agent(
|
||||
user_id="full_agent", name="满负荷坐席", status="online"
|
||||
)
|
||||
full_agent.current_load = 5
|
||||
full_agent.max_load = 5
|
||||
|
||||
db_session.add_all([old_agent, full_agent])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_max_load", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "max_old_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "full_agent", "满负荷坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3005
|
||||
assert "满负荷" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_nonexistent_conversation_error_3003(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证接手不存在的会话 → 3003。"""
|
||||
agent = create_test_agent(
|
||||
user_id="grab_ghost_agent", name="幽灵坐席", status="online"
|
||||
)
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
fake_id = str(uuid.uuid4())
|
||||
headers = await login_agent(client, "grab_ghost_agent", "幽灵坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{fake_id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# SessionService._get_conversation 抛出 ERR_CONVERSATION_NOT_FOUND (3003)
|
||||
assert data["code"] == 3003
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_ai_handling_status_error_3013(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证不能接手 ai_handling 状态的会话 → 3013。"""
|
||||
other_agent = create_test_agent(
|
||||
user_id="ai_other", name="AI坐席", status="online"
|
||||
)
|
||||
grabber = create_test_agent(
|
||||
user_id="ai_grabber", name="接手坐席", status="online"
|
||||
)
|
||||
db_session.add_all([other_agent, grabber])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_ai_handling", status="ai_handling"
|
||||
)
|
||||
conv.assigned_agent_id = "ai_other"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "ai_grabber", "接手坐席")
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# ai_handling 不是 serving,应返回 3013
|
||||
assert data["code"] == 3013
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_without_auth_returns_unauthorized(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未登录时接手请求返回未授权错误。"""
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_no_auth", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "some_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 不带 Authorization 头
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002 # ERR_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grab_old_agent_load_never_goes_negative(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证原坐席 current_load 不会变为负数(源码有 if > 0 保护)。"""
|
||||
# 原坐席 current_load 为 0(异常数据场景)
|
||||
old_agent = create_test_agent(
|
||||
user_id="zero_load_old", name="零负荷原坐席", status="online"
|
||||
)
|
||||
old_agent.current_load = 0
|
||||
|
||||
new_agent = create_test_agent(
|
||||
user_id="zero_load_new", name="新坐席", status="online"
|
||||
)
|
||||
new_agent.current_load = 0
|
||||
|
||||
db_session.add_all([old_agent, new_agent])
|
||||
await db_session.flush()
|
||||
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_zero_load", status="serving"
|
||||
)
|
||||
conv.assigned_agent_id = "zero_load_old"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "zero_load_new", "新坐席")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/grab",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
|
||||
# 原坐席 load 应仍为 0(不会变为负数)
|
||||
stmt = select(Agent).where(Agent.user_id == "zero_load_old")
|
||||
result = await db_session.execute(stmt)
|
||||
refreshed_old = result.scalars().first()
|
||||
assert refreshed_old.current_load == 0
|
||||
|
||||
# 新坐席 load 应为 1
|
||||
stmt = select(Agent).where(Agent.user_id == "zero_load_new")
|
||||
result = await db_session.execute(stmt)
|
||||
refreshed_new = result.scalars().first()
|
||||
assert refreshed_new.current_load == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 四、会话列表 N+1 查询优化验证
|
||||
# =============================================================================
|
||||
|
||||
class TestConversationListN1Optimization:
|
||||
"""测试会话列表接口的 N+1 查询优化。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_query_agent_names(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证多个会话涉及多个坐席时,assigned_agent_name 全部正确返回。
|
||||
|
||||
这间接验证了 N+1 优化:所有坐席姓名通过一次 IN 查询获取。
|
||||
如果 N+1 没优化,此测试仍会通过,但此测试确保批量查询结果映射正确。
|
||||
"""
|
||||
# 创建3个坐席
|
||||
agents = [
|
||||
create_test_agent(user_id=f"batch_agent_{i}", name=f"坐席{i+1}")
|
||||
for i in range(3)
|
||||
]
|
||||
db_session.add_all(agents)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建3个会话,分别分配给不同坐席
|
||||
convs = [
|
||||
create_test_conversation(
|
||||
employee_id=f"emp_batch_{i}", status="serving"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
for i, conv in enumerate(convs):
|
||||
conv.assigned_agent_id = f"batch_agent_{i}"
|
||||
|
||||
db_session.add_all(convs)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "batch_agent_0", "坐席1")
|
||||
response = await client.get("/conversations", headers=headers)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
items = data["data"]["items"]
|
||||
item_map = {item["id"]: item for item in items}
|
||||
|
||||
# 验证所有坐席姓名正确
|
||||
for i, conv in enumerate(convs):
|
||||
item = item_map[str(conv.id)]
|
||||
assert item["assigned_agent_name"] == f"坐席{i+1}", \
|
||||
f"会话 {conv.id} 的坐席姓名应为 '坐席{i+1}',实际为 '{item['assigned_agent_name']}'"
|
||||
@@ -0,0 +1,237 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 会话状态流转测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 会话创建默认状态为 queued
|
||||
# 2. 坐席接单:queued → serving
|
||||
# 3. 结单:serving → resolved
|
||||
# 4. 重复接单处理
|
||||
# 5. 已结单会话的操作限制
|
||||
# 6. 置顶/取消置顶切换
|
||||
# 7. 代办/取消代办切换
|
||||
# 8. 会话列表过滤
|
||||
# =============================================================================
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.services.session_service import SessionService
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
class TestConversationStateFlow:
|
||||
"""测试会话状态流转。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_conversation_default_status_queued(self, db_session):
|
||||
"""验证新会话默认状态为 queued。"""
|
||||
conv = create_test_conversation()
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
assert conv.status == "queued"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assign_conversation_to_serving(self, db_session):
|
||||
"""验证坐席接单将会话状态改为 serving。"""
|
||||
conv = create_test_conversation(status="queued")
|
||||
agent = create_test_agent(user_id="agent001", name="坐席小王")
|
||||
db_session.add_all([conv, agent])
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.assign_agent(conv.id, "agent001")
|
||||
|
||||
assert result.status == "serving"
|
||||
assert result.assigned_agent_id == "agent001"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conversation(self, db_session):
|
||||
"""验证结单将会话状态改为 resolved。"""
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.resolve_conversation(conv.id)
|
||||
|
||||
assert result.status == "resolved"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_queued_conversation_is_allowed(self, db_session):
|
||||
"""验证 queued 状态的会话可以直接结单(员工问题自行解决)。"""
|
||||
conv = create_test_conversation(status="queued")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
# queued → resolved 是合法的状态流转
|
||||
result = await session_service.resolve_conversation(conv.id)
|
||||
assert result.status == "resolved"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cannot_resolve_already_resolved_conversation(self, db_session):
|
||||
"""验证已结单的会话不能再结单。"""
|
||||
conv = create_test_conversation(status="resolved")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
from app.utils.response import AppException
|
||||
with pytest.raises(AppException):
|
||||
await session_service.resolve_conversation(conv.id)
|
||||
|
||||
|
||||
class TestConversationToggle:
|
||||
"""测试会话标记切换。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toggle_pin(self, db_session):
|
||||
"""验证置顶切换:未置顶→置顶。"""
|
||||
conv = create_test_conversation(is_pinned=False)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.toggle_pin(conv.id)
|
||||
assert result.is_pinned is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toggle_pin_off(self, db_session):
|
||||
"""验证置顶切换:置顶→取消置顶。"""
|
||||
conv = create_test_conversation(is_pinned=True)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.toggle_pin(conv.id)
|
||||
assert result.is_pinned is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toggle_todo(self, db_session):
|
||||
"""验证代办切换:未代办→代办。"""
|
||||
conv = create_test_conversation(is_todo=False)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.toggle_todo(conv.id)
|
||||
assert result.is_todo is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_toggle_todo_off(self, db_session):
|
||||
"""验证代办切换:代办→取消代办。"""
|
||||
conv = create_test_conversation(is_todo=True)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result = await session_service.toggle_todo(conv.id)
|
||||
assert result.is_todo is False
|
||||
|
||||
|
||||
class TestConversationList:
|
||||
"""测试会话列表查询。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_conversations(self, db_session):
|
||||
"""验证获取所有会话。"""
|
||||
convs = [
|
||||
create_test_conversation(employee_id=f"list_user_{i}", status="queued")
|
||||
for i in range(3)
|
||||
]
|
||||
db_session.add_all(convs)
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result, total = await session_service.get_conversations()
|
||||
|
||||
assert total >= 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_by_status(self, db_session):
|
||||
"""验证按状态过滤会话。"""
|
||||
db_session.add(create_test_conversation(employee_id="filter_queued", status="queued"))
|
||||
db_session.add(create_test_conversation(employee_id="filter_serving", status="serving"))
|
||||
await db_session.flush()
|
||||
|
||||
session_service = SessionService(db_session)
|
||||
result, total = await session_service.get_conversations(status="queued")
|
||||
|
||||
for conv in result:
|
||||
assert conv.status == "queued"
|
||||
|
||||
|
||||
class TestConversationAPI:
|
||||
"""测试会话管理 API 端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversations_endpoint(self, client, db_session, mock_redis):
|
||||
"""验证 GET /api/conversations 返回正确格式。"""
|
||||
conv = create_test_conversation(employee_id="api_list_user")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 登录坐席获取 token(/api/conversations 需要 get_current_agent 认证)
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "conv_list_agent", "name": "会话列表坐席"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
response = await client.get(
|
||||
"/conversations",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert "items" in data["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conversation_endpoint(self, client, db_session, mock_redis):
|
||||
"""验证 POST /api/conversations/{id}/resolve 结单。
|
||||
|
||||
权限:只有主责坐席(assigned_agent_id)才能结单。
|
||||
"""
|
||||
# 先创建坐席
|
||||
from app.models.agent import Agent as AgentModel
|
||||
agent = AgentModel(
|
||||
user_id="resolve_test_agent",
|
||||
name="结单测试坐席",
|
||||
status="online",
|
||||
current_load=0,
|
||||
max_load=5,
|
||||
)
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建分配给此坐席的会话
|
||||
conv = create_test_conversation(employee_id="api_resolve_user", status="serving")
|
||||
conv.assigned_agent_id = "resolve_test_agent"
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 登录坐席获取 token
|
||||
login_resp = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": "resolve_test_agent", "name": "结单测试坐席"},
|
||||
)
|
||||
token = login_resp.json()["data"]["token"]
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/resolve",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["status"] == "resolved"
|
||||
@@ -0,0 +1,925 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — H5 OAuth2 认证流程测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. OAuth2 授权 URL 接口(GET /api/h5/oauth/authorize)
|
||||
# 2. OAuth2 回调接口(POST /api/h5/oauth/callback)
|
||||
# 3. Token 验证依赖函数 _get_current_employee
|
||||
# 4. 获取当前员工信息(GET /api/h5/me)
|
||||
# 5. 向后兼容(X-Employee-Id 头降级)
|
||||
# 6. 错误处理(WecomService 失败、Redis 不可用)
|
||||
# =============================================================================
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import Base, get_db
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.funny_phrase import FunnyPhrase
|
||||
from tests.conftest import MockRedis, create_test_conversation, test_engine, test_session_factory
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 专用 fixtures:带 h5 API Redis mock 的测试客户端
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def h5_client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncClient:
|
||||
"""提供针对 H5 OAuth2 API 的异步测试客户端。
|
||||
|
||||
与 conftest.py 的 client fixture 类似,但额外 mock 了
|
||||
app.api.h5 模块中的 _get_redis,确保 OAuth2 流程中
|
||||
Redis 操作使用内存模拟。
|
||||
"""
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
from app.main import create_app
|
||||
|
||||
app = create_app()
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
|
||||
with patch("app.api.h5._get_redis", return_value=mock_redis):
|
||||
with patch("redis.asyncio.from_url", return_value=mock_redis):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_fresh() -> MockRedis:
|
||||
"""提供干净的模拟 Redis(每个测试独立)。"""
|
||||
return MockRedis()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. OAuth2 授权 URL 接口
|
||||
# ===========================================================================
|
||||
|
||||
class TestOAuthAuthorizeURL:
|
||||
"""测试 GET /api/h5/oauth/authorize — 获取企微 OAuth2 授权 URL。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_returns_correct_structure(self, h5_client):
|
||||
"""验证返回结构包含 authorize_url 字段。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert "authorize_url" in data["data"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_contains_correct_base(self, h5_client):
|
||||
"""验证授权 URL 以企微 OAuth2 基础地址开头。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
assert url.startswith("https://open.weixin.qq.com/connect/oauth2/authorize")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_contains_appid(self, h5_client):
|
||||
"""验证授权 URL 包含 appid 参数(企微 CorpID)。"""
|
||||
from app.config import settings
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
# corp_id 来自实际配置(可能是 .env 覆盖后的值)
|
||||
assert f"appid={settings.wecom_corp_id}" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_contains_scope_snsapi_base(self, h5_client):
|
||||
"""验证授权 URL 使用 snsapi_base 作用域(静默授权)。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
assert "scope=snsapi_base" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_contains_response_type_code(self, h5_client):
|
||||
"""验证授权 URL 包含 response_type=code。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
assert "response_type=code" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_contains_wechat_redirect(self, h5_client):
|
||||
"""验证授权 URL 末尾包含 #wechat_redirect。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
assert url.endswith("#wechat_redirect")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_with_redirect_uri_param(self, h5_client):
|
||||
"""验证传入 redirect_uri 参数时 URL 包含自定义回调地址。"""
|
||||
custom_uri = "https://myapp.example.com/h5/"
|
||||
response = await h5_client.get(
|
||||
"/h5/oauth/authorize",
|
||||
params={"redirect_uri": custom_uri},
|
||||
)
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
# redirect_uri 需要经过 URL 编码
|
||||
from urllib.parse import quote
|
||||
encoded = quote(custom_uri, safe="")
|
||||
assert f"redirect_uri={encoded}" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_with_host_header(self, h5_client):
|
||||
"""验证使用 Host 头构造默认回调地址。"""
|
||||
response = await h5_client.get(
|
||||
"/h5/oauth/authorize",
|
||||
headers={"Host": "myapp.example.com"},
|
||||
)
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
# Host 头构造的 URL 应使用 https 协议
|
||||
from urllib.parse import quote
|
||||
expected_redirect = quote("https://myapp.example.com/h5/", safe="")
|
||||
assert f"redirect_uri={expected_redirect}" in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorize_url_without_redirect_uri_uses_default(self, h5_client):
|
||||
"""验证不带 redirect_uri 且无 Host 头时使用配置默认值。"""
|
||||
response = await h5_client.get("/h5/oauth/authorize")
|
||||
data = response.json()
|
||||
url = data["data"]["authorize_url"]
|
||||
# 应该仍然返回有效的 URL(使用默认 origin)
|
||||
assert "redirect_uri=" in url
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. OAuth2 回调接口
|
||||
# ===========================================================================
|
||||
|
||||
class TestOAuthCallback:
|
||||
"""测试 POST /api/h5/oauth/callback — OAuth2 回调处理。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_returns_token_and_employee_info(self, h5_client, mock_redis):
|
||||
"""验证 OAuth2 回调返回 token 和员工信息。"""
|
||||
# Mock WecomService
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "test_user_001", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "测试员工",
|
||||
"department": [1, 2],
|
||||
"position": "工程师",
|
||||
"avatar": "https://avatar.example.com/test.jpg",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_auth_code"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
# 验证返回字段
|
||||
assert "token" in data["data"]
|
||||
assert data["data"]["employee_id"] == "test_user_001"
|
||||
assert data["data"]["employee_name"] == "测试员工"
|
||||
assert data["data"]["department"] == "1,2"
|
||||
assert data["data"]["position"] == "工程师"
|
||||
assert data["data"]["avatar"] == "https://avatar.example.com/test.jpg"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_stores_token_in_redis(self, h5_client, mock_redis):
|
||||
"""验证 token 存入 Redis,key 格式为 employee:token:{token}。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "redis_test_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "Redis测试",
|
||||
"department": [],
|
||||
"position": "",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_auth_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
|
||||
# 验证 Redis 中存在对应的 key
|
||||
stored = await mock_redis.get(f"employee:token:{token}")
|
||||
assert stored is not None
|
||||
assert stored == b"redis_test_user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_caches_employee_info_in_redis(self, h5_client, mock_redis):
|
||||
"""验证员工信息缓存到 Redis。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "cache_test_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "缓存测试",
|
||||
"department": [3],
|
||||
"position": "经理",
|
||||
"avatar": "https://avatar.example.com/cache.jpg",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_auth_code"},
|
||||
)
|
||||
|
||||
# 验证 Redis 中存在员工信息缓存
|
||||
cached = await mock_redis.get("employee:info:cache_test_user")
|
||||
assert cached is not None
|
||||
cached_info = json.loads(cached)
|
||||
assert cached_info["employee_id"] == "cache_test_user"
|
||||
assert cached_info["employee_name"] == "缓存测试"
|
||||
assert cached_info["department"] == "3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_with_empty_userid_returns_error(self, h5_client, mock_redis):
|
||||
"""验证 OAuth2 返回空 UserID 时报错。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "", "user_ticket": ""})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "bad_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] != 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_wecom_service_failure(self, h5_client, mock_redis):
|
||||
"""验证 WecomService 调用失败时的错误处理。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(side_effect=Exception("企微API不可用"))
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "will_fail"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] != 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_detail_fetch_failure_still_returns_token(self, h5_client, mock_redis):
|
||||
"""验证获取员工详细信息失败时仍返回 token(降级处理)。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "degrade_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(side_effect=Exception("通讯录API失败"))
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 应该仍然返回成功,token 和 employee_id
|
||||
assert data["code"] == 0
|
||||
assert "token" in data["data"]
|
||||
assert data["data"]["employee_id"] == "degrade_user"
|
||||
# 详细信息为空降级
|
||||
assert data["data"]["employee_name"] == ""
|
||||
assert data["data"]["department"] == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_missing_code_field(self, h5_client, mock_redis):
|
||||
"""验证缺少 code 字段时返回参数错误。"""
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={},
|
||||
)
|
||||
# Pydantic 验证失败
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_callback_empty_code_field(self, h5_client, mock_redis):
|
||||
"""验证空 code 字段时返回参数错误。"""
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": ""},
|
||||
)
|
||||
# Pydantic min_length=1 验证失败
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. Token 验证依赖函数 _get_current_employee
|
||||
# ===========================================================================
|
||||
|
||||
class TestGetCurrentEmployee:
|
||||
"""测试 _get_current_employee 依赖注入函数。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_bearer_token(self, h5_client, mock_redis):
|
||||
"""验证有效 Bearer token 返回对应 employee_id。"""
|
||||
# 预设 Redis 中的 token 和员工信息缓存
|
||||
await mock_redis.setex("employee:token:test_valid_token", 28800, "authed_user_001")
|
||||
employee_info = {
|
||||
"employee_id": "authed_user_001",
|
||||
"employee_name": "认证测试用户",
|
||||
"department": "IT部",
|
||||
"position": "工程师",
|
||||
"mobile": "",
|
||||
"email": "",
|
||||
"avatar": "",
|
||||
}
|
||||
await mock_redis.setex(
|
||||
"employee:info:authed_user_001",
|
||||
28800,
|
||||
json.dumps(employee_info, ensure_ascii=False),
|
||||
)
|
||||
|
||||
# 调用需要认证的 /api/h5/me 接口
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer test_valid_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# 接口成功返回,说明认证通过
|
||||
assert data["code"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_token_returns_unauthorized(self, h5_client, mock_redis):
|
||||
"""验证无效 token 返回 401(业务码 1002)。"""
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer non_existent_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
assert "未授权" in data["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_authorization_header(self, h5_client, mock_redis):
|
||||
"""验证缺少 Authorization 头返回未授权。"""
|
||||
response = await h5_client.get("/h5/me")
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_authorization_header(self, h5_client, mock_redis):
|
||||
"""验证空的 Authorization 头返回未授权。"""
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": ""},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bearer_prefix_extraction(self, h5_client, mock_redis):
|
||||
"""验证 Bearer 前缀正确提取 token。"""
|
||||
# 设置 Redis token 和员工信息缓存
|
||||
await mock_redis.setex("employee:token:my_token_123", 28800, "prefix_test_user")
|
||||
employee_info = {
|
||||
"employee_id": "prefix_test_user",
|
||||
"employee_name": "前缀测试",
|
||||
"department": "",
|
||||
"position": "",
|
||||
"mobile": "",
|
||||
"email": "",
|
||||
"avatar": "",
|
||||
}
|
||||
await mock_redis.setex(
|
||||
"employee:info:prefix_test_user",
|
||||
28800,
|
||||
json.dumps(employee_info, ensure_ascii=False),
|
||||
)
|
||||
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer my_token_123"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 认证通过,接口返回成功
|
||||
assert data["code"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_without_bearer_prefix(self, h5_client, mock_redis):
|
||||
"""验证不带 Bearer 前缀的 token 也能被识别(兼容)。"""
|
||||
await mock_redis.setex("employee:token:raw_token_456", 28800, "raw_token_user")
|
||||
employee_info = {
|
||||
"employee_id": "raw_token_user",
|
||||
"employee_name": "原始Token测试",
|
||||
"department": "",
|
||||
"position": "",
|
||||
"mobile": "",
|
||||
"email": "",
|
||||
"avatar": "",
|
||||
}
|
||||
await mock_redis.setex(
|
||||
"employee:info:raw_token_user",
|
||||
28800,
|
||||
json.dumps(employee_info, ensure_ascii=False),
|
||||
)
|
||||
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "raw_token_456"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 源码中:如果 token 不以 "Bearer " 开头,直接使用整个值
|
||||
assert data["code"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_token_returns_unauthorized(self, h5_client, mock_redis):
|
||||
"""验证过期 token(Redis 中不存在)返回未授权。"""
|
||||
# 不在 Redis 中设置任何 token,模拟过期
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer expired_token_xyz"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. GET /api/h5/me 接口
|
||||
# ===========================================================================
|
||||
|
||||
class TestGetCurrentEmployeeInfo:
|
||||
"""测试 GET /api/h5/me — 获取当前员工详细信息。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_returns_employee_info_from_cache(self, h5_client, mock_redis):
|
||||
"""验证从 Redis 缓存读取员工信息。"""
|
||||
# 预设 token 和缓存信息
|
||||
await mock_redis.setex("employee:token:cache_me_token", 28800, "me_cache_user")
|
||||
employee_info = {
|
||||
"employee_id": "me_cache_user",
|
||||
"employee_name": "缓存用户",
|
||||
"department": "技术部",
|
||||
"position": "开发",
|
||||
"mobile": "13800138000",
|
||||
"email": "cache@test.com",
|
||||
"avatar": "https://avatar.example.com/me.jpg",
|
||||
}
|
||||
await mock_redis.setex(
|
||||
"employee:info:me_cache_user",
|
||||
28800,
|
||||
json.dumps(employee_info, ensure_ascii=False),
|
||||
)
|
||||
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer cache_me_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["employee_id"] == "me_cache_user"
|
||||
assert data["data"]["employee_name"] == "缓存用户"
|
||||
# is_vip 由接口补充
|
||||
assert data["data"]["is_vip"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_falls_back_to_wecom_api(self, h5_client, mock_redis):
|
||||
"""验证缓存不存在时从企微 API 获取员工信息。"""
|
||||
# 预设 token 但不设缓存
|
||||
await mock_redis.setex("employee:token:nocache_me_token", 28800, "me_nocache_user")
|
||||
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "API用户",
|
||||
"department": [5],
|
||||
"position": "测试",
|
||||
"avatar": "https://avatar.example.com/api.jpg",
|
||||
"mobile": "13900139000",
|
||||
"email": "api@test.com",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer nocache_me_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["employee_id"] == "me_nocache_user"
|
||||
assert data["data"]["employee_name"] == "API用户"
|
||||
assert data["data"]["department"] == "5"
|
||||
assert data["data"]["mobile"] == "13900139000"
|
||||
assert data["data"]["is_vip"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_unauthenticated_returns_401(self, h5_client, mock_redis):
|
||||
"""验证未认证时 /me 返回 401。"""
|
||||
response = await h5_client.get("/h5/me")
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. 向后兼容
|
||||
# ===========================================================================
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""测试向后兼容:X-Employee-Id 头降级模式。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_x_employee_id_header_still_works_for_old_endpoints(self, h5_client, db_session, mock_redis):
|
||||
"""验证旧版 X-Employee-Id 头仍可用于兼容旧接口。
|
||||
|
||||
注意:新接口(/h5/me, /h5/oauth/*)使用 Bearer Token,
|
||||
但旧端点(如 /h5/conversations/current)使用 _get_current_employee
|
||||
也支持旧方式需要看具体端点实现。此处验证旧方式在
|
||||
_get_employee_id 中仍然工作。
|
||||
"""
|
||||
# /h5/user 接口使用 _get_current_employee(需要 Bearer Token)
|
||||
# 但 /h5/conversations/current 也用 _get_current_employee
|
||||
# 旧版 _get_employee_id 只在特定端点使用
|
||||
|
||||
# 测试通过 Bearer Token 方式访问 /h5/conversations/current
|
||||
await mock_redis.setex("employee:token:compat_token", 28800, "compat_user")
|
||||
|
||||
# 先创建一个会话
|
||||
conv = create_test_conversation(
|
||||
employee_id="compat_user",
|
||||
status="queued",
|
||||
)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
response = await h5_client.get(
|
||||
"/h5/conversations/current",
|
||||
headers={"Authorization": "Bearer compat_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_old_x_employee_id_header_not_accepted_by_new_auth(self, h5_client, mock_redis):
|
||||
"""验证仅用 X-Employee-Id 头(无 Bearer Token)访问新接口返回未授权。
|
||||
|
||||
新的 _get_current_employee 只认 Bearer Token,
|
||||
不认 X-Employee-Id。这是正确的安全行为。
|
||||
"""
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"X-Employee-Id": "old_style_user"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 新接口只认 Bearer Token,X-Employee-Id 不应通过认证
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. 错误处理
|
||||
# ===========================================================================
|
||||
|
||||
class TestErrorHandling:
|
||||
"""测试错误处理场景。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_unavailable_during_token_validation(self, h5_client, mock_redis):
|
||||
"""验证 Redis 不可用时 token 验证降级返回未授权。"""
|
||||
# 模拟 Redis get 抛出异常
|
||||
original_get = mock_redis.get
|
||||
|
||||
async def broken_get(key):
|
||||
raise Exception("Redis connection refused")
|
||||
|
||||
mock_redis.get = broken_get
|
||||
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer some_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# Redis 不可用时应返回未授权
|
||||
assert data["code"] == 1002
|
||||
|
||||
# 恢复
|
||||
mock_redis.get = original_get
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_write_failure_during_callback(self, h5_client, mock_redis):
|
||||
"""验证 Redis 写入失败时 OAuth2 回调仍能完成(降级处理)。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "redis_fail_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "Redis故障测试",
|
||||
"department": [],
|
||||
"position": "",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
# Mock Redis setex to fail
|
||||
original_setex = mock_redis.setex
|
||||
|
||||
async def broken_setex(name, time, value):
|
||||
raise Exception("Redis write failed")
|
||||
|
||||
mock_redis.setex = broken_setex
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# Redis 写入失败不应阻塞 OAuth2 回调流程
|
||||
# token 仍然返回(虽然不会被持久化)
|
||||
assert data["code"] == 0
|
||||
assert "token" in data["data"]
|
||||
assert data["data"]["employee_id"] == "redis_fail_user"
|
||||
|
||||
# 恢复
|
||||
mock_redis.setex = original_setex
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wecom_oauth_failure_returns_error(self, h5_client, mock_redis):
|
||||
"""验证企微 OAuth2 服务失败时返回错误。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(
|
||||
side_effect=Exception("企微API超时")
|
||||
)
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "timeout_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] != 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_wecom_api_failure(self, h5_client, mock_redis):
|
||||
"""验证 /me 接口企微 API 失败时返回错误。"""
|
||||
# 预设 token 但不设缓存
|
||||
await mock_redis.setex("employee:token:wecom_fail_token", 28800, "wecom_fail_user")
|
||||
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_user_info = AsyncMock(
|
||||
side_effect=Exception("通讯录API失败")
|
||||
)
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": "Bearer wecom_fail_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
# 缓存不存在 + 企微API失败,应返回错误
|
||||
assert data["code"] != 0
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7. Token TTL 与格式
|
||||
# ===========================================================================
|
||||
|
||||
class TestTokenTTLAndFormat:
|
||||
"""测试 Token TTL 和格式。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_stored_with_correct_ttl(self, h5_client, mock_redis):
|
||||
"""验证 Token 存入 Redis 时设置了正确的 TTL(8小时=28800秒)。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "ttl_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "TTL测试",
|
||||
"department": [],
|
||||
"position": "",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "ttl_test_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
|
||||
# 验证 TTL
|
||||
ttl = mock_redis._ttl.get(f"employee:token:{token}")
|
||||
assert ttl == 28800 # 8小时 = 28800 秒
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_employee_info_cache_has_same_ttl(self, h5_client, mock_redis):
|
||||
"""验证员工信息缓存与 Token 使用相同的 TTL。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "info_ttl_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "InfoTTL测试",
|
||||
"department": [],
|
||||
"position": "",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "info_ttl_code"},
|
||||
)
|
||||
|
||||
# 验证员工信息缓存的 TTL
|
||||
info_ttl = mock_redis._ttl.get("employee:info:info_ttl_user")
|
||||
assert info_ttl == 28800
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_is_urlsafe(self, h5_client, mock_redis):
|
||||
"""验证生成的 Token 是 URL-safe 格式(secrets.token_urlsafe)。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "fmt_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "格式测试",
|
||||
"department": [],
|
||||
"position": "",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "fmt_test_code"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
# token 应该是非空字符串
|
||||
assert isinstance(token, str)
|
||||
assert len(token) > 0
|
||||
# URL-safe base64 字符集:A-Z, a-z, 0-9, -, _
|
||||
import re
|
||||
assert re.match(r'^[A-Za-z0-9_-]+$', token), f"Token '{token}' is not URL-safe"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 8. Schema 验证
|
||||
# ===========================================================================
|
||||
|
||||
class TestSchemaValidation:
|
||||
"""测试 Pydantic Schema 验证。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_request_requires_code(self, h5_client, mock_redis):
|
||||
"""验证 OAuthCallbackRequest 必须包含 code 字段。"""
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_request_code_min_length(self, h5_client, mock_redis):
|
||||
"""验证 code 字段最小长度为 1。"""
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": ""},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_callback_request_valid_code(self, h5_client, mock_redis):
|
||||
"""验证有效的 code 字段格式被接受。"""
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={"userid": "schema_user", "user_ticket": ""})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={"name": "", "department": [], "position": "", "avatar": ""})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "valid_code_here"},
|
||||
)
|
||||
# 请求格式正确,应返回 200(非 422)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 9. 端到端 OAuth2 流程
|
||||
# ===========================================================================
|
||||
|
||||
class TestOAuth2EndToEnd:
|
||||
"""测试完整的 OAuth2 认证流程。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_oauth2_flow(self, h5_client, mock_redis):
|
||||
"""验证完整 OAuth2 流程:获取授权URL → 回调获取token → 用token访问/me。"""
|
||||
# Step 1: 获取授权 URL
|
||||
auth_response = await h5_client.get("/h5/oauth/authorize")
|
||||
assert auth_response.json()["code"] == 0
|
||||
auth_url = auth_response.json()["data"]["authorize_url"]
|
||||
assert "snsapi_base" in auth_url
|
||||
|
||||
# Step 2: 模拟回调获取 token
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={
|
||||
"userid": "e2e_user",
|
||||
"user_ticket": "",
|
||||
})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "端到端用户",
|
||||
"department": [1, 2],
|
||||
"position": "架构师",
|
||||
"avatar": "https://avatar.example.com/e2e.jpg",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
callback_response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "e2e_auth_code"},
|
||||
)
|
||||
|
||||
callback_data = callback_response.json()
|
||||
assert callback_data["code"] == 0
|
||||
token = callback_data["data"]["token"]
|
||||
assert token # token 非空
|
||||
|
||||
# Step 3: 使用 token 访问 /me
|
||||
me_response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
me_data = me_response.json()
|
||||
assert me_data["code"] == 0
|
||||
assert me_data["data"]["employee_id"] == "e2e_user"
|
||||
assert me_data["data"]["employee_name"] == "端到端用户"
|
||||
assert me_data["data"]["is_vip"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_with_cached_info(self, h5_client, mock_redis):
|
||||
"""验证 OAuth2 流程完成后,后续 /me 请求从缓存读取。"""
|
||||
# Step 1: 模拟回调
|
||||
mock_wecom = AsyncMock()
|
||||
mock_wecom.get_oauth_user_info = AsyncMock(return_value={
|
||||
"userid": "cached_flow_user",
|
||||
"user_ticket": "",
|
||||
})
|
||||
mock_wecom.get_user_info = AsyncMock(return_value={
|
||||
"name": "缓存流程用户",
|
||||
"department": [10],
|
||||
"position": "产品",
|
||||
"avatar": "",
|
||||
})
|
||||
mock_wecom.close = AsyncMock()
|
||||
|
||||
with patch("app.api.h5.WecomService", return_value=mock_wecom):
|
||||
callback_response = await h5_client.post(
|
||||
"/h5/oauth/callback",
|
||||
json={"code": "cached_flow_code"},
|
||||
)
|
||||
|
||||
token = callback_response.json()["data"]["token"]
|
||||
|
||||
# Step 2: 第一次访问 /me(应从缓存读取,不再调用 WecomService)
|
||||
with patch("app.api.h5.WecomService") as MockWecomClass:
|
||||
me_response = await h5_client.get(
|
||||
"/h5/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
# WecomService 不应被实例化(因为缓存命中)
|
||||
MockWecomClass.assert_not_called()
|
||||
|
||||
me_data = me_response.json()
|
||||
assert me_data["code"] == 0
|
||||
assert me_data["data"]["employee_name"] == "缓存流程用户"
|
||||
@@ -0,0 +1,218 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — H5 摇人功能测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 摇人成功(新建会话 + 举手标记 + 趣味话术返回)
|
||||
# 2. 摇人成功(已有会话 + 更新举手标记)
|
||||
# 3. 缺少 employee_id 请求失败
|
||||
# 4. 获取当前会话
|
||||
# 5. H5 发送消息
|
||||
# 6. 审批链接获取
|
||||
# 7. 软件下载列表获取
|
||||
# =============================================================================
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.funny_phrase import FunnyPhrase
|
||||
from tests.conftest import create_test_conversation, MockRedis
|
||||
|
||||
|
||||
class TestShakeEndpoint:
|
||||
"""测试摇人 API 端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shake_creates_new_conversation(self, client, db_session):
|
||||
"""验证摇人时如果没有活跃会话则创建新会话。"""
|
||||
# 先添加趣味话术
|
||||
phrase = FunnyPhrase(scene="shake", content="摇人话术测试", tone="亲切", sort_order=1)
|
||||
db_session.add(phrase)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(
|
||||
"/h5/conversations/current/shake",
|
||||
json={"employee_id": "shake_new_user", "employee_name": "测试员工"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["conversation"]["tags"]["hand_raise"] is True
|
||||
assert data["data"]["funny_phrase"] != ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shake_updates_existing_conversation(self, client, db_session):
|
||||
"""验证摇人时如果已有活跃会话则更新举手标记。"""
|
||||
conv = create_test_conversation(
|
||||
employee_id="shake_existing_user",
|
||||
status="queued",
|
||||
tags={},
|
||||
)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 添加话术
|
||||
phrase = FunnyPhrase(scene="shake", content="更新摇人话术", tone="亲切", sort_order=1)
|
||||
db_session.add(phrase)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(
|
||||
"/h5/conversations/current/shake",
|
||||
json={"employee_id": "shake_existing_user", "employee_name": "已有用户"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["conversation"]["tags"]["hand_raise"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shake_returns_funny_phrase(self, client, db_session):
|
||||
"""验证摇人返回趣味话术。"""
|
||||
phrase = FunnyPhrase(scene="shake", content="测试趣味话术内容", tone="亲切", sort_order=1)
|
||||
db_session.add(phrase)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(
|
||||
"/h5/conversations/current/shake",
|
||||
json={"employee_id": "phrase_test_user", "employee_name": "话术测试"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["data"]["funny_phrase"] != ""
|
||||
|
||||
|
||||
class TestH5CurrentConversation:
|
||||
"""测试 H5 获取当前会话。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_conversation_exists(self, client, db_session, mock_redis):
|
||||
"""验证获取当前活跃会话。"""
|
||||
# 预设 Bearer Token(替代旧的 X-Employee-Id 头)
|
||||
await mock_redis.setex("employee:token:h5_current_token", 28800, "h5_current_user")
|
||||
|
||||
conv = create_test_conversation(employee_id="h5_current_user", status="queued")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.get(
|
||||
"/h5/conversations/current",
|
||||
headers={"Authorization": "Bearer h5_current_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_conversation_not_found(self, client, db_session, mock_redis):
|
||||
"""验证无活跃会话时返回空数据。"""
|
||||
# 预设 Bearer Token
|
||||
await mock_redis.setex("employee:token:no_conv_token", 28800, "no_conversation_user")
|
||||
|
||||
response = await client.get(
|
||||
"/h5/conversations/current",
|
||||
headers={"Authorization": "Bearer no_conv_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_conversation_no_employee_id(self, client, db_session):
|
||||
"""验证缺少员工ID时返回未授权错误。"""
|
||||
response = await client.get("/h5/conversations/current")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] != 0 # 应返回错误码
|
||||
|
||||
|
||||
class TestH5SendMessage:
|
||||
"""测试 H5 发送消息。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_creates_conversation(self, client, db_session, mock_redis):
|
||||
"""验证发送消息时自动创建会话。"""
|
||||
# 预设 Bearer Token
|
||||
await mock_redis.setex("employee:token:h5_msg_token", 28800, "h5_msg_user")
|
||||
|
||||
response = await client.post(
|
||||
"/h5/conversations/current/messages",
|
||||
json={"content": "VPN连不上了"},
|
||||
headers={"Authorization": "Bearer h5_msg_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["content"] == "VPN连不上了"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_empty_content(self, client, db_session, mock_redis):
|
||||
"""验证空消息内容返回错误。"""
|
||||
# 预设 Bearer Token
|
||||
await mock_redis.setex("employee:token:empty_msg_token", 28800, "empty_msg_user")
|
||||
|
||||
response = await client.post(
|
||||
"/h5/conversations/current/messages",
|
||||
json={"content": ""},
|
||||
headers={"Authorization": "Bearer empty_msg_token"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] != 0
|
||||
|
||||
|
||||
class TestApprovalLinks:
|
||||
"""测试审批链接获取。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_approval_links(self, client, seeded_db):
|
||||
"""验证获取审批链接列表。"""
|
||||
response = await client.get("/h5/approval-links")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert len(data["data"]["items"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_approval_links_by_category(self, client, seeded_db):
|
||||
"""验证按分类过滤审批链接。"""
|
||||
response = await client.get("/h5/approval-links?category=IT")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
for item in data["data"]["items"]:
|
||||
assert item["category"] == "IT"
|
||||
|
||||
|
||||
class TestSoftwareDownloads:
|
||||
"""测试软件下载列表。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_software_downloads(self, client, seeded_db):
|
||||
"""验证获取软件下载列表。"""
|
||||
response = await client.get("/h5/software-downloads")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert len(data["data"]["items"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_software_downloads_by_category(self, client, seeded_db):
|
||||
"""验证按分类过滤软件下载。"""
|
||||
response = await client.get("/h5/software-downloads?category=办公")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
for item in data["data"]["items"]:
|
||||
assert item["category"] == "办公"
|
||||
@@ -0,0 +1,793 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 邀请功能(Participant)单元测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 一、邀请参与者(POST /api/conversations/{id}/invite-participant)
|
||||
# 1. 成功邀请:participants 更新,系统消息创建
|
||||
# 2. 非主责坐席邀请 → 3030
|
||||
# 3. 非服务中会话邀请 → 3031
|
||||
# 4. 重复邀请(所有被邀请人已在) → 3032
|
||||
# 5. 邀请不存在的会话 → 3003
|
||||
# 6. 未认证邀请 → 401
|
||||
#
|
||||
# 二、加入会话(POST /api/conversations/{id}/join)
|
||||
# 1. 成功加入:joined 状态更新,系统消息创建
|
||||
# 2. 未被邀请者加入 → 3034
|
||||
# 3. 已结束会话加入 → 3033
|
||||
# 4. 不存在的会话加入 → 3003
|
||||
#
|
||||
# 三、移除参与者(DELETE /api/conversations/{id}/participants/{user_id})
|
||||
# 1. 成功移除:从 participants 中移除,系统消息创建
|
||||
# 2. 非主责坐席移除 → 3035
|
||||
# 3. 移除不在列表中的人员 → 3036
|
||||
# 4. 未认证移除 → 401
|
||||
#
|
||||
# 四、参与者退出(POST /api/conversations/{id}/leave-participant)
|
||||
# 1. 成功退出:从 participants 中移除,系统消息创建
|
||||
# 2. 非参与者退出 → 3037
|
||||
# 3. 不存在的会话退出 → 3003
|
||||
#
|
||||
# 五、端到端闭环
|
||||
# 1. 邀请 → 加入 → 退出(完整生命周期)
|
||||
# 2. 邀请 → 加入 → 坐席移除(管理员操作)
|
||||
# =============================================================================
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 辅助函数
|
||||
# =============================================================================
|
||||
|
||||
async def login_agent(client, user_id: str, name: str) -> dict:
|
||||
"""登录坐席并返回认证头字典。
|
||||
|
||||
做什么:调用登录 API 获取 token,组装 Authorization 头
|
||||
为什么:invite-participant 和 remove-participant 端点需要坐席认证
|
||||
|
||||
Args:
|
||||
client: httpx 异步测试客户端
|
||||
user_id: 坐席ID
|
||||
name: 坐席名称
|
||||
|
||||
Returns:
|
||||
dict: {"Authorization": "Bearer xxx"}
|
||||
"""
|
||||
response = await client.post(
|
||||
"/agents/login",
|
||||
json={"user_id": user_id, "name": name},
|
||||
)
|
||||
data = response.json()
|
||||
token = data["data"]["token"]
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
async def create_serving_conversation_with_participants(
|
||||
db_session: AsyncSession,
|
||||
employee_id: str = "emp_001",
|
||||
agent_user_id: str = "agent_owner",
|
||||
participants: list = None,
|
||||
) -> Conversation:
|
||||
"""创建一个 serving 状态且有主责坐席的会话(可选已有参与者)。
|
||||
|
||||
做什么:创建测试会话,设置 assigned_agent_id 和 participants
|
||||
为什么:邀请功能测试需要 serving 状态的会话作为前提
|
||||
|
||||
Args:
|
||||
db_session: 数据库会话
|
||||
employee_id: 员工ID
|
||||
agent_user_id: 主责坐席ID
|
||||
participants: 已有参与者列表
|
||||
|
||||
Returns:
|
||||
Conversation: 创建的会话对象
|
||||
"""
|
||||
conv = create_test_conversation(
|
||||
employee_id=employee_id,
|
||||
status="serving",
|
||||
)
|
||||
conv.assigned_agent_id = agent_user_id
|
||||
conv.participants = participants or []
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
return conv
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 一、邀请参与者测试
|
||||
# =============================================================================
|
||||
|
||||
class TestInviteParticipant:
|
||||
"""测试邀请参与者接口 POST /api/conversations/{id}/invite-participant。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_success_updates_participants(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功邀请:participants 列表更新,返回包含新参与者。
|
||||
|
||||
场景:主责坐席邀请2名员工加入会话。
|
||||
"""
|
||||
# 创建坐席
|
||||
owner = create_test_agent(user_id="owner_001", name="坐席A", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建 serving 会话,分配给 owner
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session, employee_id="emp_invite", agent_user_id="owner_001"
|
||||
)
|
||||
|
||||
# 坐席A 登录并发起邀请
|
||||
headers = await login_agent(client, "owner_001", "坐席A")
|
||||
|
||||
# Mock WebSocket 广播(避免真实 WS 连接)
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_zhang", "name": "张三", "department": "技术部", "type": "employee"},
|
||||
{"id": "emp_li", "name": "李四", "department": "财务部", "type": "employee"},
|
||||
],
|
||||
"history_mode": "recent10",
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 验证响应
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
participants = data["data"]["participants"]
|
||||
# 验证 participants 包含新添加的两人
|
||||
participant_ids = [p["id"] for p in participants]
|
||||
assert "emp_zhang" in participant_ids
|
||||
assert "emp_li" in participant_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_non_owner_agent_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证非主责坐席无法邀请 → 错误码 3030。
|
||||
|
||||
场景:协作坐席(非主责)尝试邀请他人。
|
||||
"""
|
||||
# 创建两个坐席
|
||||
owner = create_test_agent(user_id="owner_002", name="主责坐席", status="online")
|
||||
other = create_test_agent(user_id="other_002", name="其他坐席", status="online")
|
||||
db_session.add_all([owner, other])
|
||||
await db_session.flush()
|
||||
|
||||
# 创建会话,主责是 owner
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session, employee_id="emp_002", agent_user_id="owner_002"
|
||||
)
|
||||
|
||||
# 其他坐席登录并尝试邀请
|
||||
headers = await login_agent(client, "other_002", "其他坐席")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_wang", "name": "王五", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# 验证:返回错误码 3030(后端所有 AppException 返回 HTTP 200 + 业务错误码)
|
||||
data = response.json()
|
||||
assert data["code"] == 3030
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_non_serving_conversation_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证非服务中会话无法邀请 → 错误码 3031。
|
||||
|
||||
场景:对已结单(closed)的会话尝试邀请。
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_003", name="坐席C", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建 closed 状态的会话
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_003",
|
||||
status="closed",
|
||||
)
|
||||
conv.assigned_agent_id = "owner_003"
|
||||
conv.participants = []
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "owner_003", "坐席C")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_zhao", "name": "赵六", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3031
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_duplicate_participants_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证重复邀请同一批人 → 错误码 3032。
|
||||
|
||||
场景:参与者已在列表中,再次邀请相同的人。
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_004", name="坐席D", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
# 创建已有参与者的会话
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_004",
|
||||
agent_user_id="owner_004",
|
||||
participants=[
|
||||
{"id": "emp_dup", "name": "重复人", "department": "技术部", "type": "employee"},
|
||||
],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_004", "坐席D")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_dup", "name": "重复人", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3032
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_nonexistent_conversation(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证邀请不存在的会话 → 错误码 3003。"""
|
||||
owner = create_test_agent(user_id="owner_005", name="坐席E", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
headers = await login_agent(client, "owner_005", "坐席E")
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{fake_id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_x", "name": "某人", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3003
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_without_auth_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未认证邀请 → 错误码 1002。"""
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session, employee_id="emp_noauth", agent_user_id="owner_noauth"
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_y", "name": "某人", "type": "employee"},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 二、加入会话测试
|
||||
# =============================================================================
|
||||
|
||||
class TestJoinConversation:
|
||||
"""测试加入会话接口 POST /api/conversations/{id}/join。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_success_updates_joined_status(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功加入:joined 状态更新为 True。
|
||||
|
||||
场景:被邀请员工通过链接加入会话。
|
||||
"""
|
||||
# 创建会话,已有被邀请但未加入的参与者
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_join_001",
|
||||
agent_user_id="agent_join",
|
||||
participants=[
|
||||
{"id": "emp_zhang", "name": "张三", "department": "技术部", "type": "employee", "joined": False},
|
||||
],
|
||||
)
|
||||
|
||||
# Mock WebSocket 广播
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/join",
|
||||
json={"employee_id": "emp_zhang"},
|
||||
)
|
||||
|
||||
# 验证响应
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
# 验证 joined 状态
|
||||
participants = data["data"]["participants"]
|
||||
zhang = next(p for p in participants if p["id"] == "emp_zhang")
|
||||
assert zhang["joined"] is True
|
||||
assert "joined_at" in zhang
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_not_invited_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未被邀请者无法加入 → 错误码 3034。
|
||||
|
||||
场景:未被邀请的员工尝试加入会话。
|
||||
"""
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_join_002",
|
||||
agent_user_id="agent_join_002",
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/join",
|
||||
json={"employee_id": "emp_hacker"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3034
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_closed_conversation_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证已结单会话无法加入 → 错误码 3033。
|
||||
|
||||
场景:被邀请人尝试加入已结束的会话。
|
||||
"""
|
||||
conv = create_test_conversation(
|
||||
employee_id="emp_join_003",
|
||||
status="closed",
|
||||
)
|
||||
conv.assigned_agent_id = "agent_join_003"
|
||||
conv.participants = [
|
||||
{"id": "emp_late", "name": "迟到者", "type": "employee", "joined": False},
|
||||
]
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/join",
|
||||
json={"employee_id": "emp_late"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3033
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_nonexistent_conversation(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证加入不存在的会话 → 错误码 3003。"""
|
||||
fake_id = str(uuid.uuid4())
|
||||
response = await client.post(
|
||||
f"/conversations/{fake_id}/join",
|
||||
json={"employee_id": "emp_ghost"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3003
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 三、移除参与者测试
|
||||
# =============================================================================
|
||||
|
||||
class TestRemoveParticipant:
|
||||
"""测试移除参与者接口 DELETE /api/conversations/{id}/participants/{user_id}。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_success(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功移除:从 participants 列表中移除目标。
|
||||
|
||||
场景:主责坐席移除一名参与者。
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_rm", name="坐席RM", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_rm_001",
|
||||
agent_user_id="owner_rm",
|
||||
participants=[
|
||||
{"id": "emp_target", "name": "被移除人", "type": "employee", "joined": True},
|
||||
{"id": "emp_keep", "name": "保留人", "type": "employee", "joined": True},
|
||||
],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_rm", "坐席RM")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.delete(
|
||||
f"/conversations/{conv.id}/participants/emp_target",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
# 验证被移除人不在列表中
|
||||
participants = data["data"]["participants"]
|
||||
participant_ids = [p["id"] for p in participants]
|
||||
assert "emp_target" not in participant_ids
|
||||
assert "emp_keep" in participant_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_non_owner_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证非主责坐席无法移除 → 错误码 3035。"""
|
||||
owner = create_test_agent(user_id="owner_rm2", name="主责", status="online")
|
||||
other = create_test_agent(user_id="other_rm2", name="其他坐席", status="online")
|
||||
db_session.add_all([owner, other])
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_rm_002",
|
||||
agent_user_id="owner_rm2",
|
||||
participants=[
|
||||
{"id": "emp_victim", "name": "被移除人", "type": "employee"},
|
||||
],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "other_rm2", "其他坐席")
|
||||
|
||||
response = await client.delete(
|
||||
f"/conversations/{conv.id}/participants/emp_victim",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3035
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent_participant(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证移除不在列表中的人员 → 错误码 3036。"""
|
||||
owner = create_test_agent(user_id="owner_rm3", name="坐席", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_rm_003",
|
||||
agent_user_id="owner_rm3",
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_rm3", "坐席")
|
||||
|
||||
response = await client.delete(
|
||||
f"/conversations/{conv.id}/participants/emp_ghost",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3036
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_without_auth_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证未认证移除 → 错误码 1002。"""
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_rm_noauth",
|
||||
agent_user_id="owner_noauth",
|
||||
participants=[
|
||||
{"id": "emp_target", "name": "被移除人", "type": "employee"},
|
||||
],
|
||||
)
|
||||
|
||||
response = await client.delete(
|
||||
f"/conversations/{conv.id}/participants/emp_target",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 四、参与者退出测试
|
||||
# =============================================================================
|
||||
|
||||
class TestLeaveAsParticipant:
|
||||
"""测试参与者退出接口 POST /api/conversations/{id}/leave-participant。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_success(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证成功退出:从 participants 列表中移除自己。
|
||||
|
||||
场景:被邀请人主动退出会话。
|
||||
"""
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_leave_001",
|
||||
agent_user_id="agent_leave",
|
||||
participants=[
|
||||
{"id": "emp_leaver", "name": "退出者", "type": "employee", "joined": True},
|
||||
{"id": "emp_stayer", "name": "留守者", "type": "employee", "joined": True},
|
||||
],
|
||||
)
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave-participant",
|
||||
json={"employee_id": "emp_leaver"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
# 验证退出者不在列表中
|
||||
participants = data["data"]["participants"]
|
||||
participant_ids = [p["id"] for p in participants]
|
||||
assert "emp_leaver" not in participant_ids
|
||||
assert "emp_stayer" in participant_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_not_participant_rejected(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证非参与者退出 → 错误码 3037。
|
||||
|
||||
场景:未被邀请的人尝试退出会话。
|
||||
"""
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_leave_002",
|
||||
agent_user_id="agent_leave_002",
|
||||
)
|
||||
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/leave-participant",
|
||||
json={"employee_id": "emp_stranger"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3037
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_leave_nonexistent_conversation(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证退出不存在的会话 → 错误码 3003。"""
|
||||
fake_id = str(uuid.uuid4())
|
||||
response = await client.post(
|
||||
f"/conversations/{fake_id}/leave-participant",
|
||||
json={"employee_id": "emp_ghost"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 3003
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 五、端到端闭环测试
|
||||
# =============================================================================
|
||||
|
||||
class TestInviteEndToEnd:
|
||||
"""邀请功能端到端闭环测试。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_invite_join_leave(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证完整生命周期:邀请 → 加入 → 退出。
|
||||
|
||||
场景:
|
||||
1. 坐席邀请张三
|
||||
2. 张三加入
|
||||
3. 张三退出
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_e2e", name="坐席E2E", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_e2e_001",
|
||||
agent_user_id="owner_e2e",
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_e2e", "坐席E2E")
|
||||
|
||||
# Step 1: 邀请
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
invite_resp = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_e2e_zhang", "name": "张三", "department": "技术部", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert invite_resp.status_code == 200
|
||||
invite_data = invite_resp.json()
|
||||
participants_after_invite = invite_data["data"]["participants"]
|
||||
assert any(p["id"] == "emp_e2e_zhang" for p in participants_after_invite)
|
||||
|
||||
# Step 2: 加入
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
join_resp = await client.post(
|
||||
f"/conversations/{conv.id}/join",
|
||||
json={"employee_id": "emp_e2e_zhang"},
|
||||
)
|
||||
|
||||
assert join_resp.status_code == 200
|
||||
join_data = join_resp.json()
|
||||
participants_after_join = join_data["data"]["participants"]
|
||||
zhang = next(p for p in participants_after_join if p["id"] == "emp_e2e_zhang")
|
||||
assert zhang["joined"] is True
|
||||
|
||||
# Step 3: 退出
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
leave_resp = await client.post(
|
||||
f"/conversations/{conv.id}/leave-participant",
|
||||
json={"employee_id": "emp_e2e_zhang"},
|
||||
)
|
||||
|
||||
assert leave_resp.status_code == 200
|
||||
leave_data = leave_resp.json()
|
||||
participants_after_leave = leave_data["data"]["participants"]
|
||||
assert not any(p["id"] == "emp_e2e_zhang" for p in participants_after_leave)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_invite_join_remove(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证完整生命周期:邀请 → 加入 → 坐席移除。
|
||||
|
||||
场景:
|
||||
1. 坐席邀请李四
|
||||
2. 李四加入
|
||||
3. 坐席移除李四
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_e2e2", name="坐席E2E2", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_e2e_002",
|
||||
agent_user_id="owner_e2e2",
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_e2e2", "坐席E2E2")
|
||||
|
||||
# Step 1: 邀请
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
invite_resp = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_e2e_li", "name": "李四", "department": "财务部", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert invite_resp.status_code == 200
|
||||
|
||||
# Step 2: 加入
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
join_resp = await client.post(
|
||||
f"/conversations/{conv.id}/join",
|
||||
json={"employee_id": "emp_e2e_li"},
|
||||
)
|
||||
|
||||
assert join_resp.status_code == 200
|
||||
|
||||
# Step 3: 坐席移除
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
remove_resp = await client.delete(
|
||||
f"/conversations/{conv.id}/participants/emp_e2e_li",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert remove_resp.status_code == 200
|
||||
remove_data = remove_resp.json()
|
||||
participants_after_remove = remove_data["data"]["participants"]
|
||||
assert not any(p["id"] == "emp_e2e_li" for p in participants_after_remove)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invite_partial_duplicate_merges(
|
||||
self, client, db_session, mock_redis
|
||||
):
|
||||
"""验证邀请部分新人 + 部分已在人员:只添加新人,忽略已有人。
|
||||
|
||||
场景:会话已有张三,再邀请张三和王五,只有王五被添加。
|
||||
"""
|
||||
owner = create_test_agent(user_id="owner_merge", name="坐席合并", status="online")
|
||||
db_session.add(owner)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await create_serving_conversation_with_participants(
|
||||
db_session,
|
||||
employee_id="emp_merge",
|
||||
agent_user_id="owner_merge",
|
||||
participants=[
|
||||
{"id": "emp_existing", "name": "已有张三", "department": "技术部", "type": "employee"},
|
||||
],
|
||||
)
|
||||
|
||||
headers = await login_agent(client, "owner_merge", "坐席合并")
|
||||
|
||||
with patch("app.services.ws_manager.manager.broadcast", new_callable=AsyncMock):
|
||||
response = await client.post(
|
||||
f"/conversations/{conv.id}/invite-participant",
|
||||
json={
|
||||
"participants": [
|
||||
{"id": "emp_existing", "name": "已有张三", "type": "employee"},
|
||||
{"id": "emp_new", "name": "新人王五", "department": "市场部", "type": "employee"},
|
||||
],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
participants = data["data"]["participants"]
|
||||
participant_ids = [p["id"] for p in participants]
|
||||
assert "emp_existing" in participant_ids # 已有的仍在
|
||||
assert "emp_new" in participant_ids # 新人被添加
|
||||
@@ -0,0 +1,543 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 消息去重功能测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. MsgId 重复消息被过滤
|
||||
# 2. 相同用户 + 内容重复被过滤
|
||||
# 3. 不同消息正常通过
|
||||
# 4. TTL 过期后消息可正常处理
|
||||
# 5. Redis 不可用时降级放行
|
||||
# 6. CacheService 独立方法测试
|
||||
# 7. MessageRouter 集成去重测试
|
||||
# =============================================================================
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.services.cache_service import (
|
||||
CacheService,
|
||||
MSG_DEDUP_PREFIX,
|
||||
CONTENT_DEDUP_PREFIX,
|
||||
DEFAULT_MSG_DEDUP_TTL,
|
||||
DEFAULT_CONTENT_DEDUP_TTL,
|
||||
)
|
||||
from app.services.message_router import MessageRouter
|
||||
from app.services.scoring_service import ScoringService
|
||||
from app.services.wecom_service import WecomService
|
||||
from tests.conftest import MockRedis, create_test_conversation
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_router_db(db_session):
|
||||
"""初始化路由器所需的数据库配置。"""
|
||||
configs = [
|
||||
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","真人"]'),
|
||||
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉"]'),
|
||||
SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上"]'),
|
||||
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕"]'),
|
||||
SystemConfig(config_key="intervene_round_threshold", config_value="3"),
|
||||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1"),
|
||||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_vip_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1"),
|
||||
]
|
||||
db_session.add_all(configs)
|
||||
await db_session.flush()
|
||||
|
||||
|
||||
def _create_mock_wecom_service():
|
||||
"""创建模拟的 WecomService。"""
|
||||
mock = AsyncMock(spec=WecomService)
|
||||
mock.get_user_info = AsyncMock(return_value={
|
||||
"name": "张三",
|
||||
"department": "[1, 2]",
|
||||
"position": "工程师",
|
||||
})
|
||||
mock.send_text_message = AsyncMock(return_value={"errcode": 0})
|
||||
mock.close = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_wecom_service():
|
||||
return _create_mock_wecom_service()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client():
|
||||
"""提供干净的 MockRedis 实例。"""
|
||||
return MockRedis()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_service(mock_redis_client):
|
||||
"""提供带 MockRedis 的 CacheService。"""
|
||||
return CacheService(mock_redis_client)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_service_no_redis():
|
||||
"""提供无 Redis 的 CacheService(降级模式)。"""
|
||||
return CacheService(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router_with_dedup(db_session, mock_wecom_service, mock_redis_client, setup_router_db):
|
||||
"""创建带去重功能的消息路由器。"""
|
||||
scoring_service = ScoringService(db_session)
|
||||
cache_service = CacheService(mock_redis_client)
|
||||
return MessageRouter(
|
||||
db=db_session,
|
||||
wecom_service=mock_wecom_service,
|
||||
scoring_service=scoring_service,
|
||||
cache_service=cache_service,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def router_no_dedup(db_session, mock_wecom_service, setup_router_db):
|
||||
"""创建无去重功能的消息路由器(cache_service=None)。"""
|
||||
scoring_service = ScoringService(db_session)
|
||||
return MessageRouter(
|
||||
db=db_session,
|
||||
wecom_service=mock_wecom_service,
|
||||
scoring_service=scoring_service,
|
||||
cache_service=None,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CacheService 独立测试
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCacheServiceIsDuplicate:
|
||||
"""测试 CacheService.is_duplicate() 方法。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_message_not_duplicate(self, cache_service, mock_redis_client):
|
||||
"""首次消息不应被判定为重复。"""
|
||||
result = await cache_service.is_duplicate("msg_001")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_msg_id_is_duplicate(self, cache_service, mock_redis_client):
|
||||
"""相同 MsgId 的第二次调用应被判定为重复。"""
|
||||
# 第一次:非重复
|
||||
result1 = await cache_service.is_duplicate("msg_002")
|
||||
assert result1 is False
|
||||
|
||||
# 第二次:重复
|
||||
result2 = await cache_service.is_duplicate("msg_002")
|
||||
assert result2 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_msg_id_not_duplicate(self, cache_service, mock_redis_client):
|
||||
"""不同 MsgId 不应互相影响。"""
|
||||
result1 = await cache_service.is_duplicate("msg_003")
|
||||
result2 = await cache_service.is_duplicate("msg_004")
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_msg_id_not_duplicate(self, cache_service):
|
||||
"""空 MsgId 应放行(不判断为重复)。"""
|
||||
result = await cache_service.is_duplicate("")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_redis_graceful_degradation(self, cache_service_no_redis):
|
||||
"""Redis 不可用时应降级放行(返回 False)。"""
|
||||
result = await cache_service_no_redis.is_duplicate("msg_005")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_key_format(self, cache_service, mock_redis_client):
|
||||
"""验证 Redis key 格式为 msg:dedup:{msg_id}。"""
|
||||
await cache_service.is_duplicate("msg_006")
|
||||
expected_key = f"{MSG_DEDUP_PREFIX}:msg_006"
|
||||
assert expected_key in mock_redis_client._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_key_ttl(self, cache_service, mock_redis_client):
|
||||
"""验证 Redis key 设置了正确的 TTL。"""
|
||||
await cache_service.is_duplicate("msg_007")
|
||||
expected_key = f"{MSG_DEDUP_PREFIX}:msg_007"
|
||||
assert mock_redis_client._ttl.get(expected_key) == DEFAULT_MSG_DEDUP_TTL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_ttl(self, cache_service, mock_redis_client):
|
||||
"""验证自定义 TTL 生效。"""
|
||||
custom_ttl = 600
|
||||
await cache_service.is_duplicate("msg_008", ttl=custom_ttl)
|
||||
expected_key = f"{MSG_DEDUP_PREFIX}:msg_008"
|
||||
assert mock_redis_client._ttl.get(expected_key) == custom_ttl
|
||||
|
||||
|
||||
class TestCacheServiceIsDuplicateContent:
|
||||
"""测试 CacheService.is_duplicate_content() 方法。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_message_not_duplicate(self, cache_service):
|
||||
"""首次消息不应被判定为内容重复。"""
|
||||
result = await cache_service.is_duplicate_content("user_001", "帮我重置密码")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_same_content_is_duplicate(self, cache_service):
|
||||
"""相同用户发送相同内容应被判定为重复。"""
|
||||
# 第一次:非重复
|
||||
result1 = await cache_service.is_duplicate_content("user_002", "VPN连不上")
|
||||
assert result1 is False
|
||||
|
||||
# 第二次:重复
|
||||
result2 = await cache_service.is_duplicate_content("user_002", "VPN连不上")
|
||||
assert result2 is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_different_content_not_duplicate(self, cache_service):
|
||||
"""相同用户发送不同内容不应被判定为重复。"""
|
||||
result1 = await cache_service.is_duplicate_content("user_003", "重置密码")
|
||||
result2 = await cache_service.is_duplicate_content("user_003", "安装软件")
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_user_same_content_not_duplicate(self, cache_service):
|
||||
"""不同用户发送相同内容不应被判定为重复。"""
|
||||
result1 = await cache_service.is_duplicate_content("user_004", "VPN连不上")
|
||||
result2 = await cache_service.is_duplicate_content("user_005", "VPN连不上")
|
||||
assert result1 is False
|
||||
assert result2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_user_id_not_duplicate(self, cache_service):
|
||||
"""空 user_id 应放行。"""
|
||||
result = await cache_service.is_duplicate_content("", "帮我重置密码")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_content_not_duplicate(self, cache_service):
|
||||
"""空 content 应放行。"""
|
||||
result = await cache_service.is_duplicate_content("user_006", "")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_redis_graceful_degradation(self, cache_service_no_redis):
|
||||
"""Redis 不可用时应降级放行。"""
|
||||
result = await cache_service_no_redis.is_duplicate_content("user_007", "VPN连不上")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_key_format(self, cache_service, mock_redis_client):
|
||||
"""验证 Redis key 包含用户ID和内容哈希。"""
|
||||
await cache_service.is_duplicate_content("user_008", "帮我重置密码")
|
||||
content_hash = hashlib.sha256("user_008:帮我重置密码".encode("utf-8")).hexdigest()[:16]
|
||||
expected_key = f"{CONTENT_DEDUP_PREFIX}:user_008:{content_hash}"
|
||||
assert expected_key in mock_redis_client._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_dedup_ttl(self, cache_service, mock_redis_client):
|
||||
"""验证内容去重的 TTL 默认为 60 秒。"""
|
||||
await cache_service.is_duplicate_content("user_009", "VPN连不上")
|
||||
content_hash = hashlib.sha256("user_009:VPN连不上".encode("utf-8")).hexdigest()[:16]
|
||||
expected_key = f"{CONTENT_DEDUP_PREFIX}:user_009:{content_hash}"
|
||||
assert mock_redis_client._ttl.get(expected_key) == DEFAULT_CONTENT_DEDUP_TTL
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TTL 过期测试
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTTLExpiry:
|
||||
"""测试 TTL 过期后消息可正常处理。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msg_id_dedup_key_expires(self, cache_service, mock_redis_client):
|
||||
"""验证 MsgId 去重 key 可通过手动删除模拟过期后放行。"""
|
||||
msg_id = "msg_expire_001"
|
||||
|
||||
# 首次:非重复
|
||||
result1 = await cache_service.is_duplicate(msg_id)
|
||||
assert result1 is False
|
||||
|
||||
# 重复
|
||||
result2 = await cache_service.is_duplicate(msg_id)
|
||||
assert result2 is True
|
||||
|
||||
# 模拟 TTL 过期:手动删除 key
|
||||
key = f"{MSG_DEDUP_PREFIX}:{msg_id}"
|
||||
await mock_redis_client.delete(key)
|
||||
|
||||
# 过期后:非重复
|
||||
result3 = await cache_service.is_duplicate(msg_id)
|
||||
assert result3 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_dedup_key_expires(self, cache_service, mock_redis_client):
|
||||
"""验证内容去重 key 过期后放行。"""
|
||||
user_id = "user_expire_001"
|
||||
content = "VPN连不上"
|
||||
|
||||
# 首次:非重复
|
||||
result1 = await cache_service.is_duplicate_content(user_id, content)
|
||||
assert result1 is False
|
||||
|
||||
# 重复
|
||||
result2 = await cache_service.is_duplicate_content(user_id, content)
|
||||
assert result2 is True
|
||||
|
||||
# 模拟 TTL 过期
|
||||
content_hash = hashlib.sha256(f"{user_id}:{content}".encode("utf-8")).hexdigest()[:16]
|
||||
key = f"{CONTENT_DEDUP_PREFIX}:{user_id}:{content_hash}"
|
||||
await mock_redis_client.delete(key)
|
||||
|
||||
# 过期后:非重复
|
||||
result3 = await cache_service.is_duplicate_content(user_id, content)
|
||||
assert result3 is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MessageRouter 集成去重测试
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestMessageRouterDedup:
|
||||
"""测试 MessageRouter 集成去重功能。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_msg_id_returns_none(self, router_with_dedup, mock_redis_client):
|
||||
"""相同 MsgId 的重复消息应返回 None(被过滤)。"""
|
||||
# 首次消息正常处理
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="dedup_user_001",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_dedup_001",
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# 相同 MsgId 再次调用,应被去重过滤
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="dedup_user_001",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_dedup_001",
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_content_returns_none(self, router_with_dedup, mock_redis_client):
|
||||
"""相同用户发送相同内容(不同 MsgId)应在 60 秒内被过滤。"""
|
||||
# 首次消息正常处理
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="dedup_user_002",
|
||||
content="VPN连不上",
|
||||
msg_id="msg_dedup_002a",
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# 不同 MsgId 但相同用户+内容,应被内容去重过滤
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="dedup_user_002",
|
||||
content="VPN连不上",
|
||||
msg_id="msg_dedup_002b",
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_messages_pass_through(self, router_with_dedup, mock_redis_client):
|
||||
"""不同消息应正常通过。"""
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="normal_user_001",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_normal_001",
|
||||
)
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="normal_user_001",
|
||||
content="安装Office",
|
||||
msg_id="msg_normal_002",
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result2 is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cache_service_skips_dedup(self, router_no_dedup):
|
||||
"""cache_service=None 时跳过去重检查,所有消息正常处理。"""
|
||||
# 两次相同 MsgId,但无去重 → 都正常处理
|
||||
result1 = await router_no_dedup.route_message(
|
||||
from_user_id="no_dedup_user",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_no_dedup_001",
|
||||
)
|
||||
result2 = await router_no_dedup.route_message(
|
||||
from_user_id="no_dedup_user",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_no_dedup_001",
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result2 is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_msg_id_skips_msg_id_dedup(self, router_with_dedup):
|
||||
"""msg_id=None 时跳过 MsgId 去重,但仍检查内容去重。"""
|
||||
# 第一次:无 msg_id,正常处理
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="no_msgid_user",
|
||||
content="WiFi连不上",
|
||||
msg_id=None,
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# 第二次:无 msg_id,相同用户+内容 → 内容去重命中
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="no_msgid_user",
|
||||
content="WiFi连不上",
|
||||
msg_id=None,
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_users_same_content_passes(self, router_with_dedup):
|
||||
"""不同用户发送相同内容应正常通过(内容去重是用户维度的)。"""
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="user_a",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_user_a_001",
|
||||
)
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="user_b",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_user_b_001",
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result2 is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_expired_allows_reprocessing(
|
||||
self, router_with_dedup, mock_redis_client
|
||||
):
|
||||
"""TTL 过期后,相同消息可重新处理。"""
|
||||
# 首次:正常处理
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="expire_user",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_expire_001",
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# 重复:被过滤
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="expire_user",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_expire_001",
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
# 模拟 TTL 过期:删除 Redis key
|
||||
key = f"{MSG_DEDUP_PREFIX}:msg_expire_001"
|
||||
await mock_redis_client.delete(key)
|
||||
content_hash = hashlib.sha256("expire_user:帮我重置密码".encode("utf-8")).hexdigest()[:16]
|
||||
content_key = f"{CONTENT_DEDUP_PREFIX}:expire_user:{content_hash}"
|
||||
await mock_redis_client.delete(content_key)
|
||||
|
||||
# 过期后:可重新处理
|
||||
result3 = await router_with_dedup.route_message(
|
||||
from_user_id="expire_user",
|
||||
content="帮我重置密码",
|
||||
msg_id="msg_expire_001",
|
||||
)
|
||||
assert result3 is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_text_message_dedup(self, router_with_dedup, mock_redis_client):
|
||||
"""非文本消息也应当经过去重检查。"""
|
||||
# 首次:正常处理
|
||||
result1 = await router_with_dedup.route_message(
|
||||
from_user_id="nontext_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
msg_id="msg_nontext_001",
|
||||
media_id="media_123",
|
||||
)
|
||||
assert result1 is not None
|
||||
|
||||
# 重复:被过滤
|
||||
result2 = await router_with_dedup.route_message(
|
||||
from_user_id="nontext_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
msg_id="msg_nontext_001",
|
||||
media_id="media_123",
|
||||
)
|
||||
assert result2 is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CacheService 通用缓存测试
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCacheServiceGeneral:
|
||||
"""测试 CacheService 通用缓存操作。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_key(self, cache_service, mock_redis_client):
|
||||
"""获取已存在的 key。"""
|
||||
await cache_service.set("test_key", "test_value")
|
||||
result = await cache_service.get("test_key")
|
||||
assert result == "test_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key(self, cache_service):
|
||||
"""获取不存在的 key 返回 None。"""
|
||||
result = await cache_service.get("nonexistent_key")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_with_ttl(self, cache_service, mock_redis_client):
|
||||
"""设置带 TTL 的缓存。"""
|
||||
result = await cache_service.set("ttl_key", "ttl_value", ttl=3600)
|
||||
assert result is True
|
||||
assert mock_redis_client._data.get("ttl_key") == "ttl_value"
|
||||
assert mock_redis_client._ttl.get("ttl_key") == 3600
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_without_ttl(self, cache_service, mock_redis_client):
|
||||
"""设置不带 TTL 的缓存。"""
|
||||
result = await cache_service.set("no_ttl_key", "no_ttl_value")
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_existing_key(self, cache_service, mock_redis_client):
|
||||
"""删除已存在的 key。"""
|
||||
await cache_service.set("delete_key", "delete_value")
|
||||
result = await cache_service.delete("delete_key")
|
||||
assert result is True
|
||||
assert "delete_key" not in mock_redis_client._data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_key(self, cache_service):
|
||||
"""删除不存在的 key 返回 True(Redis DELETE 语义)。"""
|
||||
result = await cache_service.delete("nonexistent_delete")
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_redis_operations_return_defaults(self, cache_service_no_redis):
|
||||
"""Redis 不可用时通用操作返回默认值。"""
|
||||
assert await cache_service_no_redis.get("any_key") is None
|
||||
assert await cache_service_no_redis.set("any_key", "any_value") is False
|
||||
assert await cache_service_no_redis.delete("any_key") is False
|
||||
@@ -0,0 +1,309 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 消息体验功能测试
|
||||
# =============================================================================
|
||||
# 说明:测试消息体验相关功能,包括:
|
||||
# 1. 撤回消息 (POST /api/messages/{id}/recall)
|
||||
# 2. 删除消息 (DELETE /api/messages/{id})
|
||||
# 3. 标记已读 (POST /api/conversations/{id}/mark-read)
|
||||
# 4. 图片上传 (POST /api/messages/image)
|
||||
# 5. 文件上传 (POST /api/messages/file)
|
||||
# =============================================================================
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例:撤回消息
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_message_within_2min(client, db_session, mock_redis):
|
||||
"""测试撤回消息 - 2分钟内可撤回
|
||||
|
||||
预期:成功撤回消息,状态变为 "recalled"
|
||||
"""
|
||||
# 创建测试会话
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
from app.models.message import Message
|
||||
# 创建2分钟内的消息
|
||||
message = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="agent",
|
||||
sender_id="test_agent_001",
|
||||
sender_name="测试坐席",
|
||||
content="测试消息内容",
|
||||
msg_type="text",
|
||||
recallable_until=datetime.now() + timedelta(minutes=2),
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.flush()
|
||||
|
||||
# 调用撤回消息接口
|
||||
response = await client.post(f"/api/messages/{message.id}/recall")
|
||||
|
||||
# 验证
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("code") == 0
|
||||
assert "撤回成功" in data.get("message", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_message_after_2min_fails(client, db_session, mock_redis):
|
||||
"""测试撤回消息 - 2分钟后不可撤回
|
||||
|
||||
预期:返回403错误
|
||||
"""
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
from app.models.message import Message
|
||||
# 创建超过2分钟的消息
|
||||
message = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="agent",
|
||||
sender_id="test_agent_001",
|
||||
sender_name="测试坐席",
|
||||
content="测试消息内容",
|
||||
msg_type="text",
|
||||
recallable_until=datetime.now() - timedelta(minutes=1), # 已过期
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(f"/api/messages/{message.id}/recall")
|
||||
|
||||
# 应该返回403错误
|
||||
assert response.status_code == 403 or (response.status_code == 200 and response.json().get("code") == 403)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_nonexistent_message(client, db_session, mock_redis):
|
||||
"""测试撤回不存在的消息
|
||||
|
||||
预期:返回404错误
|
||||
"""
|
||||
fake_id = str(uuid4())
|
||||
response = await client.post(f"/api/messages/{fake_id}/recall")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recall_non_agent_message_fails(client, db_session, mock_redis):
|
||||
"""测试��回非坐席发送的消息
|
||||
|
||||
预期:返回403错误(只能撤回坐席发送的消息)
|
||||
"""
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
from app.models.message import Message
|
||||
# 员工发送的消息
|
||||
message = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="emp_001",
|
||||
sender_name="测试员工",
|
||||
content="员工消息",
|
||||
msg_type="text",
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(f"/api/messages/{message.id}/recall")
|
||||
|
||||
# 应该返回403错误
|
||||
assert response.status_code == 403 or (response.status_code == 200 and response.json().get("code") == 403)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例:删除消息
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_message_success(client, db_session, mock_redis):
|
||||
"""测试删除消息 - 成功删除
|
||||
|
||||
预期:返回200,消息被删除
|
||||
"""
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
from app.models.message import Message
|
||||
message = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="agent",
|
||||
sender_id="test_agent_001",
|
||||
sender_name="测试坐席",
|
||||
content="测试消息内容",
|
||||
msg_type="text",
|
||||
)
|
||||
db_session.add(message)
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.delete(f"/api/messages/{message.id}")
|
||||
|
||||
assert response.status_code in [200, 204]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_message(client, db_session, mock_redis):
|
||||
"""测试删除不存在的消息
|
||||
|
||||
预期:返回404错误
|
||||
"""
|
||||
fake_id = str(uuid4())
|
||||
response = await client.delete(f"/api/messages/{fake_id}")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例:标记已读
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_read_updates_messages(client, db_session, mock_redis):
|
||||
"""测试标记会话已读
|
||||
|
||||
预期:返回200,所有未读消息被标记为已读
|
||||
"""
|
||||
conv = create_test_conversation(status="serving")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
from app.models.message import Message
|
||||
msg1 = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="emp_001",
|
||||
sender_name="员工",
|
||||
content="员工消息1",
|
||||
msg_type="text",
|
||||
is_read=False,
|
||||
)
|
||||
msg2 = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="emp_001",
|
||||
sender_name="员工",
|
||||
content="员工消息2",
|
||||
msg_type="text",
|
||||
is_read=False,
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
await db_session.flush()
|
||||
|
||||
response = await client.post(f"/api/conversations/{conv.id}/mark-read")
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("code") == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_read_nonexistent_conversation(client, db_session, mock_redis):
|
||||
"""测试标记不存在的会话已读
|
||||
|
||||
预期:返回404错误
|
||||
"""
|
||||
fake_id = str(uuid4())
|
||||
response = await client.post(f"/api/conversations/{fake_id}/mark-read")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例:图片上传
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_image_within_limit(client, db_session, mock_redis):
|
||||
"""测试图片上传 - 10MB以内
|
||||
|
||||
预期:成功上传,返回文件URL
|
||||
"""
|
||||
# 创建小图片数据(约50KB)
|
||||
image_data = b"\x89PNG\r\n\x1a\n" + b"fake_image_data" * 5000
|
||||
files = {"file": ("test.png", image_data, "image/png")}
|
||||
|
||||
response = await client.post("/api/messages/image", files=files)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("code") == 0
|
||||
assert "url" in data.get("data", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_image_exceeds_limit(client, db_session, mock_redis):
|
||||
"""测试图片上传 - 超过10MB
|
||||
|
||||
预期:返回400错误
|
||||
"""
|
||||
# 创建大于10MB的数据
|
||||
large_data = b"x" * (11 * 1024 * 1024) # 11MB
|
||||
files = {"file": ("large.png", large_data, "image/png")}
|
||||
|
||||
response = await client.post("/api/messages/image", files=files)
|
||||
|
||||
assert response.status_code == 400 or (response.status_code == 200 and response.json().get("code") == 400)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_invalid_image_type(client, db_session, mock_redis):
|
||||
"""测试上传不支持的图片格式
|
||||
|
||||
预期:返回400错误
|
||||
"""
|
||||
# 模拟不支持的格式
|
||||
image_data = b"fake_image"
|
||||
files = {"file": ("test.bmp", image_data, "image/bmp")}
|
||||
|
||||
response = await client.post("/api/messages/image", files=files)
|
||||
|
||||
assert response.status_code == 400 or (response.status_code == 200 and response.json().get("code") == 400)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 测试用例:文件上传
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_within_limit(client, db_session, mock_redis):
|
||||
"""测试文件上传 - 10MB以内
|
||||
|
||||
预期:成功上传,返回文件URL
|
||||
"""
|
||||
# 创建小文件(约50KB)
|
||||
file_data = b"fake_file_content" * 5000
|
||||
files = {"file": ("test.pdf", file_data, "application/pdf")}
|
||||
|
||||
response = await client.post("/api/messages/file", files=files)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get("code") == 0
|
||||
assert "url" in data.get("data", {})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_file_exceeds_limit(client, db_session, mock_redis):
|
||||
"""测试文件上传 - 超过10MB
|
||||
|
||||
预期:返回400错误
|
||||
"""
|
||||
large_data = b"x" * (11 * 1024 * 1024) # 11MB
|
||||
files = {"file": ("large.pdf", large_data, "application/pdf")}
|
||||
|
||||
response = await client.post("/api/messages/file", files=files)
|
||||
|
||||
assert response.status_code == 400 or (response.status_code == 200 and response.json().get("code") == 400)
|
||||
@@ -0,0 +1,285 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — MessageRouter 消息路由测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 查找或创建会话(新员工创建 / 已有会话复用)
|
||||
# 2. VIP 检测(总监/CEO/普通员工)
|
||||
# 3. 举手标记检测集成
|
||||
# 4. 情绪标记检测集成
|
||||
# 5. 需介入标记检测集成
|
||||
# 6. 紧急度评分集成(验证 Bug 1 修复:await calculate_urgency)
|
||||
# 7. 消息记录创建
|
||||
# 8. VIP 检测失败不阻塞流程
|
||||
# =============================================================================
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.services.message_router import MessageRouter
|
||||
from app.services.scoring_service import ScoringService
|
||||
from app.services.wecom_service import WecomService
|
||||
from tests.conftest import create_test_conversation, MockRedis
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_router_db(db_session):
|
||||
"""初始化路由器所需的数据库配置。"""
|
||||
configs = [
|
||||
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","真人"]'),
|
||||
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉"]'),
|
||||
SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上"]'),
|
||||
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕"]'),
|
||||
SystemConfig(config_key="intervene_round_threshold", config_value="3"),
|
||||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1"),
|
||||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_vip_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1"),
|
||||
]
|
||||
db_session.add_all(configs)
|
||||
await db_session.flush()
|
||||
|
||||
|
||||
def _create_mock_wecom_service():
|
||||
"""创建模拟的 WecomService。"""
|
||||
mock = AsyncMock(spec=WecomService)
|
||||
mock.get_user_info = AsyncMock(return_value={
|
||||
"name": "张三",
|
||||
"department": "[1, 2]",
|
||||
"position": "工程师",
|
||||
})
|
||||
mock.send_text_message = AsyncMock(return_value={"errcode": 0})
|
||||
mock.close = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def mock_wecom_service():
|
||||
return _create_mock_wecom_service()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def router(db_session, mock_wecom_service, setup_router_db):
|
||||
"""创建消息路由器实例。"""
|
||||
scoring_service = ScoringService(db_session)
|
||||
return MessageRouter(
|
||||
db=db_session,
|
||||
wecom_service=mock_wecom_service,
|
||||
scoring_service=scoring_service,
|
||||
)
|
||||
|
||||
|
||||
class TestFindOrCreateConversation:
|
||||
"""测试查找或创建会话。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_conversation(self, router, db_session):
|
||||
"""验证新员工首次发消息时创建新会话。"""
|
||||
conv = await router._find_or_create_conversation("new_employee_001", "帮我重置密码")
|
||||
|
||||
assert conv is not None
|
||||
assert conv.employee_id == "new_employee_001"
|
||||
assert conv.status == "queued"
|
||||
assert conv.urgency_score == 1
|
||||
assert conv.last_message_summary == "帮我重置密码"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reuse_existing_queued_conversation(self, router, db_session):
|
||||
"""验证已有 queued 状态的会话会被复用。"""
|
||||
# 先创建一个会话
|
||||
existing = create_test_conversation(employee_id="reuse_user", status="queued")
|
||||
db_session.add(existing)
|
||||
await db_session.flush()
|
||||
existing_id = existing.id
|
||||
|
||||
# 再次查找应复用
|
||||
conv = await router._find_or_create_conversation("reuse_user", "新消息")
|
||||
assert conv.id == existing_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reuse_existing_serving_conversation(self, router, db_session):
|
||||
"""验证已有 serving 状态的会话会被复用。"""
|
||||
existing = create_test_conversation(employee_id="serving_user", status="serving")
|
||||
db_session.add(existing)
|
||||
await db_session.flush()
|
||||
existing_id = existing.id
|
||||
|
||||
conv = await router._find_or_create_conversation("serving_user", "追加消息")
|
||||
assert conv.id == existing_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_when_resolved(self, router, db_session):
|
||||
"""验证 resolved 状态的会话不会被复用,会创建新会话。"""
|
||||
existing = create_test_conversation(employee_id="resolved_user", status="resolved")
|
||||
db_session.add(existing)
|
||||
await db_session.flush()
|
||||
|
||||
conv = await router._find_or_create_conversation("resolved_user", "新咨询")
|
||||
assert conv.id != existing.id
|
||||
assert conv.status == "queued"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_truncated_to_256(self, router, db_session):
|
||||
"""验证消息摘要截取前 256 字符。"""
|
||||
long_content = "A" * 300
|
||||
conv = await router._find_or_create_conversation("trunc_user", long_content)
|
||||
assert len(conv.last_message_summary) == 256
|
||||
|
||||
|
||||
class TestCheckVip:
|
||||
"""测试 VIP 检测。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vip_detection_for_director(self, router, db_session, mock_wecom_service):
|
||||
"""验证总监级别被识别为 VIP。"""
|
||||
mock_wecom_service.get_user_info.return_value = {
|
||||
"name": "王总监",
|
||||
"department": "[1]",
|
||||
"position": "技术总监",
|
||||
}
|
||||
|
||||
conv = create_test_conversation(employee_id="vip_director")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
await router._check_vip(conv)
|
||||
|
||||
assert conv.is_vip is True
|
||||
assert conv.employee_name == "王总监"
|
||||
assert conv.position == "技术总监"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vip_detection_for_ceo(self, router, db_session, mock_wecom_service):
|
||||
"""验证 CEO 被识别为 VIP。"""
|
||||
mock_wecom_service.get_user_info.return_value = {
|
||||
"name": "李CEO",
|
||||
"department": "[1]",
|
||||
"position": "CEO",
|
||||
}
|
||||
|
||||
conv = create_test_conversation(employee_id="vip_ceo")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
await router._check_vip(conv)
|
||||
assert conv.is_vip is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_vip_for_regular_engineer(self, router, db_session, mock_wecom_service):
|
||||
"""验证普通工程师不被识别为 VIP。"""
|
||||
mock_wecom_service.get_user_info.return_value = {
|
||||
"name": "张三",
|
||||
"department": "[1]",
|
||||
"position": "工程师",
|
||||
}
|
||||
|
||||
conv = create_test_conversation(employee_id="regular_engineer")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
await router._check_vip(conv)
|
||||
assert conv.is_vip is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vip_check_failure_does_not_block(self, router, db_session, mock_wecom_service):
|
||||
"""验证 VIP 检测 API 失败时不阻塞消息路由。"""
|
||||
mock_wecom_service.get_user_info.side_effect = Exception("API 调用失败")
|
||||
|
||||
conv = create_test_conversation(employee_id="api_fail_user")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 不应抛出异常
|
||||
await router._check_vip(conv)
|
||||
assert conv.is_vip is False # 保持默认值
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vip_check_skipped_if_already_detected(self, router, db_session, mock_wecom_service):
|
||||
"""验证已检测过 VIP 的会话不再重复检测。"""
|
||||
conv = create_test_conversation(employee_id="already_vip", is_vip=True)
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
await router._check_vip(conv)
|
||||
|
||||
# get_user_info 不应被调用(因为 is_vip 已经为 True)
|
||||
mock_wecom_service.get_user_info.assert_not_called()
|
||||
|
||||
|
||||
class TestRouteMessage:
|
||||
"""测试完整的消息路由流程。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_normal_message(self, router, db_session, mock_wecom_service):
|
||||
"""验证普通消息的路由流程。"""
|
||||
conv = await router.route_message("normal_user", "帮我重置密码")
|
||||
|
||||
assert conv is not None
|
||||
assert conv.employee_id == "normal_user"
|
||||
assert conv.status == "queued"
|
||||
assert conv.urgency_score >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_with_hand_raise(self, router, db_session, mock_wecom_service):
|
||||
"""验证举手关键词触发举手标记。"""
|
||||
conv = await router.route_message("hand_raise_user", "我要转人工")
|
||||
|
||||
assert conv.tags.get("hand_raise") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_with_emotion(self, router, db_session, mock_wecom_service):
|
||||
"""验证情绪关键词触发情绪标记。"""
|
||||
conv = await router.route_message("angry_user", "太崩溃了!系统太差了")
|
||||
|
||||
assert conv.tags.get("emotion") == "angry"
|
||||
assert "崩溃" in conv.tags.get("emotion_keywords", [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_creates_message_record(self, router, db_session, mock_wecom_service):
|
||||
"""验证路由消息时创建消息记录。"""
|
||||
conv = await router.route_message("msg_record_user", "测试消息内容")
|
||||
|
||||
# 查询消息记录
|
||||
stmt = select(Message).where(Message.conversation_id == conv.id)
|
||||
result = await db_session.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
|
||||
assert len(messages) >= 1
|
||||
msg = messages[0]
|
||||
assert msg.sender_type == "employee"
|
||||
assert msg.content == "测试消息内容"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_urgency_is_int_not_coroutine(self, router, db_session, mock_wecom_service):
|
||||
"""验证 Bug 1 修复后 urgency_score 是整数而非协程对象。"""
|
||||
conv = await router.route_message("bug1_test_user", "转人工,很急")
|
||||
|
||||
# Bug 1 修复前:urgency_score 会是 coroutine 对象
|
||||
# 修复后:urgency_score 应该是整数
|
||||
assert isinstance(conv.urgency_score, int)
|
||||
assert 1 <= conv.urgency_score <= 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_repeat_count_increments(self, router, db_session, mock_wecom_service):
|
||||
"""验证追问轮次计数递增。"""
|
||||
# 第一次消息
|
||||
conv1 = await router.route_message("repeat_user", "第一条消息")
|
||||
assert conv1.tags.get("repeat_count") == 1
|
||||
|
||||
# 第二次消息(同一会话)
|
||||
conv2 = await router.route_message("repeat_user", "第二条消息")
|
||||
assert conv2.tags.get("repeat_count") == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_message_updates_last_message_summary(self, router, db_session, mock_wecom_service):
|
||||
"""验证路由消息时更新最后消息摘要。"""
|
||||
conv = await router.route_message("summary_user", "VPN连接不上怎么办")
|
||||
assert conv.last_message_summary == "VPN连接不上怎么办"
|
||||
@@ -0,0 +1,984 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 非文本消息处理测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. _get_non_text_display() — 各消息类型的展示文本生成
|
||||
# 2. _get_non_text_reply() — 各消息类型的自动回复模板
|
||||
# 3. _handle_non_text_message() — 非文本消息核心处理流程
|
||||
# - 图片消息:正确存储 + 正确回复模板
|
||||
# - 语音消息:正确存储 + 正确回复模板
|
||||
# - 文件消息:正确存储 file_name/file_size + 正确回复
|
||||
# - 位置消息:正确存储 location 字段 + 正确回复
|
||||
# - 视频消息:正确存储 + 正确回复
|
||||
# 4. 文本消息不受影响(回归测试)
|
||||
# 5. WebSocket 广播格式验证
|
||||
# 6. 非文本消息不触发 AI、不改变会话状态
|
||||
# 7. wecom_callback.py 字段提取验证
|
||||
# =============================================================================
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.services.message_router import MessageRouter
|
||||
from app.services.scoring_service import ScoringService
|
||||
from app.services.wecom_service import WecomService
|
||||
from tests.conftest import create_test_conversation
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shared Fixtures
|
||||
# =============================================================================
|
||||
|
||||
def _create_mock_wecom_service():
|
||||
"""创建模拟的 WecomService,send_text_message 返回成功。"""
|
||||
mock = AsyncMock(spec=WecomService)
|
||||
mock.get_user_info = AsyncMock(return_value={
|
||||
"name": "测试员工",
|
||||
"department": "[1]",
|
||||
"position": "工程师",
|
||||
})
|
||||
mock.send_text_message = AsyncMock(return_value={"errcode": 0, "errmsg": "ok"})
|
||||
return mock
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def mock_wecom_service():
|
||||
"""提供模拟的 WecomService。"""
|
||||
return _create_mock_wecom_service()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def setup_configs(db_session):
|
||||
"""初始化评分服务所需的系统配置。"""
|
||||
configs = [
|
||||
SystemConfig(config_key="hand_raise_keywords", config_value='["转人工","人工","真人"]'),
|
||||
SystemConfig(config_key="emotion_keywords_angry", config_value='["崩溃","愤怒","投诉"]'),
|
||||
SystemConfig(config_key="emotion_keywords_urgent", config_value='["急","紧急","马上"]'),
|
||||
SystemConfig(config_key="emotion_keywords_worried", config_value='["担心","害怕"]'),
|
||||
SystemConfig(config_key="intervene_round_threshold", config_value="3"),
|
||||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1"),
|
||||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_vip_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1"),
|
||||
]
|
||||
db_session.add_all(configs)
|
||||
await db_session.flush()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def router_no_ai(db_session, mock_wecom_service, setup_configs):
|
||||
"""创建不含 AI 处理器的消息路由器(用于测试非文本消息,验证 AI 不被触发)。"""
|
||||
scoring_service = ScoringService(db_session)
|
||||
return MessageRouter(
|
||||
db=db_session,
|
||||
wecom_service=mock_wecom_service,
|
||||
scoring_service=scoring_service,
|
||||
ai_handler=None, # 明确设为 None,验证非文本不依赖 AI
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
def mock_ai_handler():
|
||||
"""创建模拟的 AIHandler。"""
|
||||
mock = AsyncMock()
|
||||
mock.handle_message = AsyncMock(return_value=MagicMock(
|
||||
content="AI回复内容",
|
||||
should_transfer=False,
|
||||
should_count=True,
|
||||
is_guidance=False,
|
||||
reply_type="ai_hit",
|
||||
dify_conversation_id=None,
|
||||
))
|
||||
return mock
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 1: _get_non_text_display — 展示文本生成
|
||||
# =============================================================================
|
||||
|
||||
class TestGetNonTextDisplay:
|
||||
"""测试各消息类型的展示文本生成。"""
|
||||
|
||||
def test_image_display(self, router_no_ai):
|
||||
"""验证图片消息展示文本。"""
|
||||
assert router_no_ai._get_non_text_display("image") == "[图片消息]"
|
||||
|
||||
def test_voice_display(self, router_no_ai):
|
||||
"""验证语音消息展示文本。"""
|
||||
assert router_no_ai._get_non_text_display("voice") == "[语音消息]"
|
||||
|
||||
def test_video_display(self, router_no_ai):
|
||||
"""验证视频消息展示文本。"""
|
||||
assert router_no_ai._get_non_text_display("video") == "[视频消息]"
|
||||
|
||||
def test_location_display(self, router_no_ai):
|
||||
"""验证位置消息展示文本。"""
|
||||
assert router_no_ai._get_non_text_display("location") == "[位置消息]"
|
||||
|
||||
def test_file_display_with_name(self, router_no_ai):
|
||||
"""验证文件消息展示文本(含文件名)。"""
|
||||
result = router_no_ai._get_non_text_display("file", file_name="report.pdf")
|
||||
assert result == "[文件消息: report.pdf]"
|
||||
|
||||
def test_file_display_without_name(self, router_no_ai):
|
||||
"""验证文件消息展示文本(无文件名)。"""
|
||||
result = router_no_ai._get_non_text_display("file", file_name=None)
|
||||
assert result == "[文件消息]"
|
||||
|
||||
def test_unknown_type_display(self, router_no_ai):
|
||||
"""验证未知类型的展示文本兜底。"""
|
||||
result = router_no_ai._get_non_text_display("sticker")
|
||||
assert result == "[sticker消息]"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 2: _get_non_text_reply — 自动回复模板生成
|
||||
# =============================================================================
|
||||
|
||||
class TestGetNonTextReply:
|
||||
"""测试各消息类型的自动回复模板。"""
|
||||
|
||||
def test_image_reply_suggests_description(self, router_no_ai):
|
||||
"""验证图片消息的回复引导用户补充文字描述。"""
|
||||
reply = router_no_ai._get_non_text_reply("image")
|
||||
assert "截图" in reply
|
||||
assert "补充文字描述" in reply
|
||||
assert "📷" in reply
|
||||
|
||||
def test_voice_reply_says_unsupported(self, router_no_ai):
|
||||
"""验证语音消息的回复包含'暂不支持'。"""
|
||||
reply = router_no_ai._get_non_text_reply("voice")
|
||||
assert "暂不支持语音消息" in reply
|
||||
assert "文字描述" in reply
|
||||
|
||||
def test_video_reply_says_unsupported(self, router_no_ai):
|
||||
"""验证视频消息的回复包含'暂不支持'。"""
|
||||
reply = router_no_ai._get_non_text_reply("video")
|
||||
assert "暂不支持视频消息" in reply
|
||||
|
||||
def test_file_reply_says_unsupported(self, router_no_ai):
|
||||
"""验证文件消息的回复包含'暂不支持'。"""
|
||||
reply = router_no_ai._get_non_text_reply("file")
|
||||
assert "暂不支持文件消息" in reply
|
||||
|
||||
def test_location_reply_says_unsupported(self, router_no_ai):
|
||||
"""验证位置消息的回复包含'暂不支持'。"""
|
||||
reply = router_no_ai._get_non_text_reply("location")
|
||||
assert "暂不支持位置消息" in reply
|
||||
|
||||
def test_unknown_type_fallback_reply(self, router_no_ai):
|
||||
"""验证未知消息类型的兜底回复模板。"""
|
||||
reply = router_no_ai._get_non_text_reply("sticker")
|
||||
assert "暂不支持sticker消息" in reply
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 3: _handle_non_text_message — 非文本消息核心处理
|
||||
# =============================================================================
|
||||
|
||||
class TestHandleNonTextMessage:
|
||||
"""测试 _handle_non_text_message() 方法的所有消息类型。"""
|
||||
|
||||
# --- 图片消息 ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_message_storage_and_reply(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证图片消息:正确存储 + 正确回复模板 + 不触发AI。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="image_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
media_id="media_img_001",
|
||||
extra_data={"pic_url": "https://example.com/pic.jpg"},
|
||||
)
|
||||
|
||||
# 1. 验证会话未改变状态(保持 ai_handling,非文本不改变)
|
||||
assert conv is not None
|
||||
assert conv.status == "ai_handling"
|
||||
|
||||
# 2. 验证员工消息记录已创建(含元数据)
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
).order_by(Message.created_at)
|
||||
result = await db_session.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
assert len(messages) == 1
|
||||
emp_msg = messages[0]
|
||||
assert emp_msg.msg_type == "image"
|
||||
assert emp_msg.content == "[图片消息]"
|
||||
assert emp_msg.media_id == "media_img_001"
|
||||
assert emp_msg.extra_data == {"pic_url": "https://example.com/pic.jpg"}
|
||||
|
||||
# 3. 验证 AI 自动回复消息记录
|
||||
stmt_ai = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "ai",
|
||||
)
|
||||
result_ai = await db_session.execute(stmt_ai)
|
||||
ai_msgs = list(result_ai.scalars().all())
|
||||
assert len(ai_msgs) == 1
|
||||
ai_msg = ai_msgs[0]
|
||||
assert "截图" in ai_msg.content
|
||||
assert ai_msg.msg_type == "text"
|
||||
assert ai_msg.sender_id == "ai_bot"
|
||||
|
||||
# 4. 验证企微 API 发送了回复
|
||||
mock_wecom_service.send_text_message.assert_called_once()
|
||||
call_args = mock_wecom_service.send_text_message.call_args
|
||||
assert call_args.kwargs["user_id"] == "image_user"
|
||||
assert "截图" in call_args.kwargs["content"]
|
||||
|
||||
# --- 语音消息 ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_message_storage_and_reply(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证语音消息:正确存储格式 + 正确回复模板。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="voice_user",
|
||||
content="",
|
||||
msg_type="voice",
|
||||
media_id="media_voice_001",
|
||||
extra_data={"format": "amr"},
|
||||
)
|
||||
|
||||
# 验证员工消息
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
emp_msg = result.scalars().first()
|
||||
assert emp_msg is not None
|
||||
assert emp_msg.msg_type == "voice"
|
||||
assert emp_msg.content == "[语音消息]"
|
||||
assert emp_msg.media_id == "media_voice_001"
|
||||
assert emp_msg.extra_data == {"format": "amr"}
|
||||
|
||||
# 验证 AI 回复包含"暂不支持"
|
||||
mock_wecom_service.send_text_message.assert_called_once()
|
||||
reply_text = mock_wecom_service.send_text_message.call_args.kwargs["content"]
|
||||
assert "暂不支持语音消息" in reply_text
|
||||
|
||||
# --- 视频消息 ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_video_message_storage_and_reply(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证视频消息:正确存储 thumb_media_id + 正确回复。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="video_user",
|
||||
content="",
|
||||
msg_type="video",
|
||||
media_id="media_video_001",
|
||||
extra_data={"thumb_media_id": "thumb_001"},
|
||||
)
|
||||
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
emp_msg = result.scalars().first()
|
||||
assert emp_msg.msg_type == "video"
|
||||
assert emp_msg.content == "[视频消息]"
|
||||
assert emp_msg.extra_data == {"thumb_media_id": "thumb_001"}
|
||||
|
||||
reply_text = mock_wecom_service.send_text_message.call_args.kwargs["content"]
|
||||
assert "暂不支持视频消息" in reply_text
|
||||
|
||||
# --- 文件消息 ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_message_storage_with_metadata(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证文件消息:正确存储 file_name + file_size + 正确回复。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="file_user",
|
||||
content="",
|
||||
msg_type="file",
|
||||
media_id="media_file_001",
|
||||
file_name="error_screenshot.png",
|
||||
file_size=204800,
|
||||
extra_data=None,
|
||||
)
|
||||
|
||||
# 验证员工消息
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
emp_msg = result.scalars().first()
|
||||
assert emp_msg.msg_type == "file"
|
||||
assert emp_msg.content == "[文件消息: error_screenshot.png]"
|
||||
assert emp_msg.media_id == "media_file_001"
|
||||
assert emp_msg.file_name == "error_screenshot.png"
|
||||
assert emp_msg.file_size == 204800
|
||||
|
||||
# 验证回复
|
||||
reply_text = mock_wecom_service.send_text_message.call_args.kwargs["content"]
|
||||
assert "暂不支持文件消息" in reply_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_message_without_name(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证文件消息(无文件名):展示文本正确退化。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="file_user2",
|
||||
content="",
|
||||
msg_type="file",
|
||||
media_id="media_file_002",
|
||||
file_name=None,
|
||||
file_size=None,
|
||||
)
|
||||
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
emp_msg = result.scalars().first()
|
||||
assert emp_msg.content == "[文件消息]"
|
||||
assert emp_msg.file_name is None
|
||||
assert emp_msg.file_size is None
|
||||
|
||||
# --- 位置消息 ---
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_location_message_storage(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证位置消息:正确存储 location 字段 + 正确回复。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="location_user",
|
||||
content="",
|
||||
msg_type="location",
|
||||
media_id=None,
|
||||
extra_data={
|
||||
"location_x": "23.134",
|
||||
"location_y": "113.358",
|
||||
"label": "广州市天河区",
|
||||
"scale": "15",
|
||||
},
|
||||
)
|
||||
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
emp_msg = result.scalars().first()
|
||||
assert emp_msg.msg_type == "location"
|
||||
assert emp_msg.content == "[位置消息]"
|
||||
assert emp_msg.extra_data["location_x"] == "23.134"
|
||||
assert emp_msg.extra_data["location_y"] == "113.358"
|
||||
assert emp_msg.extra_data["label"] == "广州市天河区"
|
||||
assert emp_msg.extra_data["scale"] == "15"
|
||||
|
||||
reply_text = mock_wecom_service.send_text_message.call_args.kwargs["content"]
|
||||
assert "暂不支持位置消息" in reply_text
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 4: 会话状态与 AI 触发验证
|
||||
# =============================================================================
|
||||
|
||||
class TestNonTextDoesNotTriggerAI:
|
||||
"""验证非文本消息不触发 AI、不改变会话状态。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_text_does_not_change_status(self, router_no_ai, db_session):
|
||||
"""验证非文本消息不改变已有会话的状态。"""
|
||||
# 创建已有会话(queued 状态)
|
||||
existing_conv = create_test_conversation(
|
||||
employee_id="existing_user",
|
||||
status="queued",
|
||||
)
|
||||
db_session.add(existing_conv)
|
||||
await db_session.flush()
|
||||
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="existing_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
media_id="media_existing",
|
||||
)
|
||||
|
||||
# 状态应保持不变
|
||||
assert conv.status == "queued"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_text_does_not_call_ai_handler(self, db_session, mock_wecom_service, setup_configs, mock_ai_handler):
|
||||
"""验证非文本消息不调用 AIHandler。"""
|
||||
scoring_service = ScoringService(db_session)
|
||||
router_with_ai = MessageRouter(
|
||||
db=db_session,
|
||||
wecom_service=mock_wecom_service,
|
||||
scoring_service=scoring_service,
|
||||
ai_handler=mock_ai_handler,
|
||||
)
|
||||
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
await router_with_ai.route_message(
|
||||
from_user_id="ai_skip_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
media_id="media_skip_ai",
|
||||
)
|
||||
|
||||
# AI handler 不应被调用
|
||||
mock_ai_handler.handle_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_text_reuses_existing_conversation(self, router_no_ai, db_session):
|
||||
"""验证非文本消息复用已有活跃会话。"""
|
||||
existing_conv = create_test_conversation(
|
||||
employee_id="reuse_nontext",
|
||||
status="ai_handling",
|
||||
)
|
||||
db_session.add(existing_conv)
|
||||
await db_session.flush()
|
||||
existing_id = existing_conv.id
|
||||
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="reuse_nontext",
|
||||
content="",
|
||||
msg_type="voice",
|
||||
media_id="media_reuse",
|
||||
)
|
||||
|
||||
assert conv.id == existing_id
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 5: 文本消息回归测试(文本消息不受影响)
|
||||
# =============================================================================
|
||||
|
||||
class TestTextMessageUnaffected:
|
||||
"""验证文本消息正常走 AI 流程,非文本改造不影响。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_message_routes_normally(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证普通文本消息正常路由(创建会话、评分、创建记录)。"""
|
||||
conv = await router_no_ai.route_message(
|
||||
from_user_id="text_user",
|
||||
content="帮我重置VPN密码",
|
||||
msg_type="text",
|
||||
)
|
||||
|
||||
assert conv is not None
|
||||
assert conv.employee_id == "text_user"
|
||||
assert 1 <= conv.urgency_score <= 5
|
||||
assert isinstance(conv.urgency_score, int)
|
||||
|
||||
# 验证消息记录已创建
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv.id,
|
||||
Message.sender_type == "employee",
|
||||
)
|
||||
result = await db_session.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
assert len(messages) >= 1
|
||||
assert messages[0].content == "帮我重置VPN密码"
|
||||
assert messages[0].msg_type == "text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_message_with_hand_raise_still_works(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证文本消息举手检测仍然正常工作。"""
|
||||
conv = await router_no_ai.route_message(
|
||||
from_user_id="hand_raise_text",
|
||||
content="我要转人工",
|
||||
msg_type="text",
|
||||
)
|
||||
|
||||
assert conv.tags.get("hand_raise") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_message_creates_new_conversation(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证文本消息新员工创建新会话。"""
|
||||
conv = await router_no_ai.route_message(
|
||||
from_user_id="brand_new_text_user",
|
||||
content="第一次咨询",
|
||||
msg_type="text",
|
||||
)
|
||||
|
||||
assert conv is not None
|
||||
assert conv.employee_id == "brand_new_text_user"
|
||||
assert conv.status in ("ai_handling", "queued")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_message_sets_correct_status(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证文本消息新会话状态为 ai_handling。"""
|
||||
conv = await router_no_ai.route_message(
|
||||
from_user_id="new_user_status",
|
||||
content="测试状态",
|
||||
msg_type="text",
|
||||
)
|
||||
|
||||
assert conv.status == "ai_handling"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 6: WebSocket 广播格式验证
|
||||
# =============================================================================
|
||||
|
||||
class TestWebSocketBroadcastFormat:
|
||||
"""验证非文本消息的 WebSocket 广播格式正确。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_broadcast_contains_media_fields(self, router_no_ai, db_session):
|
||||
"""验证图片消息广播包含正确的媒体字段。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="ws_image_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
media_id="media_ws_001",
|
||||
extra_data={"pic_url": "https://img.example.com/test.jpg"},
|
||||
)
|
||||
|
||||
mock_ws.broadcast.assert_called_once()
|
||||
broadcast_data = mock_ws.broadcast.call_args.args[0]
|
||||
|
||||
assert broadcast_data["type"] == "new_message"
|
||||
data = broadcast_data["data"]
|
||||
assert data["conversation_id"] == str(conv.id)
|
||||
assert data["sender_type"] == "employee"
|
||||
assert data["sender_id"] == "ws_image_user"
|
||||
assert data["msg_type"] == "image"
|
||||
assert data["media_id"] == "media_ws_001"
|
||||
assert data["content"] == "[图片消息]"
|
||||
assert data["ai_replied"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_broadcast_contains_file_fields(self, router_no_ai, db_session):
|
||||
"""验证文件消息广播包含 file_name 和 file_size。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="ws_file_user",
|
||||
content="",
|
||||
msg_type="file",
|
||||
media_id="media_ws_file",
|
||||
file_name="bug_report.docx",
|
||||
file_size=512000,
|
||||
)
|
||||
|
||||
broadcast_data = mock_ws.broadcast.call_args.args[0]
|
||||
data = broadcast_data["data"]
|
||||
assert data["file_name"] == "bug_report.docx"
|
||||
assert data["file_size"] == 512000
|
||||
assert data["msg_type"] == "file"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_broadcast_does_not_contain_media_fields(self, router_no_ai, db_session):
|
||||
"""验证文本消息广播不包含非文本专用字段(回归)。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock()
|
||||
|
||||
await router_no_ai.route_message(
|
||||
from_user_id="ws_text_user",
|
||||
content="普通文本",
|
||||
msg_type="text",
|
||||
)
|
||||
|
||||
broadcast_data = mock_ws.broadcast.call_args.args[0]
|
||||
assert broadcast_data["type"] == "new_message"
|
||||
data = broadcast_data["data"]
|
||||
assert data["sender_type"] == "employee"
|
||||
assert data["content"] == "普通文本"
|
||||
# 文本消息不应包含 media_id 字段(除非显式 None)
|
||||
assert "msg_type" not in data or data.get("msg_type") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_failure_does_not_block(self, router_no_ai, db_session, mock_wecom_service):
|
||||
"""验证 WebSocket 广播失败不阻塞非文本消息处理流程。"""
|
||||
with patch("app.services.ws_manager.manager") as mock_ws:
|
||||
mock_ws.broadcast = AsyncMock(side_effect=Exception("广播失败"))
|
||||
|
||||
# 不应抛出异常
|
||||
conv = await router_no_ai._handle_non_text_message(
|
||||
from_user_id="ws_fail_user",
|
||||
content="",
|
||||
msg_type="image",
|
||||
media_id="media_ws_fail",
|
||||
)
|
||||
|
||||
# 即使广播失败,消息仍然入库
|
||||
assert conv is not None
|
||||
# 企微回复仍然发送
|
||||
mock_wecom_service.send_text_message.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 7: wecom_callback.py 字段提取验证
|
||||
# =============================================================================
|
||||
|
||||
class TestWecomCallbackFieldExtraction:
|
||||
"""验证 wecom_callback.py 中 XML 消息字段提取逻辑。
|
||||
|
||||
注意:由于回调接口依赖完整的企微加密/解密流程,这里通过
|
||||
检查代码逻辑(白盒验证)来确认字段映射正确性。
|
||||
"""
|
||||
|
||||
def test_image_fields_mapped_correctly(self):
|
||||
"""验证图片消息XML字段 → route_message 参数的映射。
|
||||
|
||||
XML字段: MediaId, PicUrl
|
||||
应映射到: media_id, extra_data["pic_url"]
|
||||
"""
|
||||
# 模拟 wecom_callback.py 第 155-156 行的提取逻辑
|
||||
message_dict = {
|
||||
"FromUserName": "user001",
|
||||
"MsgType": "image",
|
||||
"MediaId": "img_abc123",
|
||||
"PicUrl": "https://wework.qpic.cn/xxxx",
|
||||
}
|
||||
|
||||
media_id = message_dict.get("MediaId", "")
|
||||
pic_url = message_dict.get("PicUrl", "")
|
||||
|
||||
assert media_id == "img_abc123"
|
||||
assert pic_url == "https://wework.qpic.cn/xxxx"
|
||||
|
||||
# 验证 extra_data 构建(对应第 201-202 行)
|
||||
extra_data = {"pic_url": pic_url}
|
||||
assert extra_data["pic_url"] == "https://wework.qpic.cn/xxxx"
|
||||
|
||||
def test_voice_fields_mapped_correctly(self):
|
||||
"""验证语音消息XML字段 → route_message 参数的映射。
|
||||
|
||||
XML字段: MediaId, Format
|
||||
应映射到: media_id, extra_data["format"]
|
||||
"""
|
||||
message_dict = {
|
||||
"FromUserName": "user002",
|
||||
"MsgType": "voice",
|
||||
"MediaId": "voice_abc",
|
||||
"Format": "amr",
|
||||
}
|
||||
|
||||
media_id = message_dict.get("MediaId", "")
|
||||
msg_format = message_dict.get("Format", "")
|
||||
|
||||
assert media_id == "voice_abc"
|
||||
assert msg_format == "amr"
|
||||
|
||||
extra_data = {"format": msg_format}
|
||||
assert extra_data["format"] == "amr"
|
||||
|
||||
def test_video_fields_mapped_correctly(self):
|
||||
"""验证视频消息XML字段 → route_message 参数的映射。
|
||||
|
||||
XML字段: MediaId, ThumbMediaId
|
||||
应映射到: media_id, extra_data["thumb_media_id"]
|
||||
"""
|
||||
message_dict = {
|
||||
"FromUserName": "user003",
|
||||
"MsgType": "video",
|
||||
"MediaId": "video_abc",
|
||||
"ThumbMediaId": "thumb_xyz",
|
||||
}
|
||||
|
||||
media_id = message_dict.get("MediaId", "")
|
||||
thumb_media_id = message_dict.get("ThumbMediaId", "")
|
||||
|
||||
assert media_id == "video_abc"
|
||||
assert thumb_media_id == "thumb_xyz"
|
||||
|
||||
extra_data = {"thumb_media_id": thumb_media_id}
|
||||
assert extra_data["thumb_media_id"] == "thumb_xyz"
|
||||
|
||||
def test_file_fields_mapped_correctly(self):
|
||||
"""验证文件消息XML字段 → route_message 参数的映射。
|
||||
|
||||
XML字段: MediaId, FileName, FileSize
|
||||
应映射到: media_id, file_name, file_size
|
||||
"""
|
||||
message_dict = {
|
||||
"FromUserName": "user004",
|
||||
"MsgType": "file",
|
||||
"MediaId": "file_abc",
|
||||
"FileName": "error.log",
|
||||
"FileSize": "102400",
|
||||
}
|
||||
|
||||
media_id = message_dict.get("MediaId", "")
|
||||
file_name = message_dict.get("FileName", "")
|
||||
file_size = message_dict.get("FileSize", "")
|
||||
|
||||
assert media_id == "file_abc"
|
||||
assert file_name == "error.log"
|
||||
assert file_size == "102400"
|
||||
|
||||
# 验证 file_size 类型转换(对应第 221 行 int(file_size))
|
||||
file_size_int = int(file_size) if file_size else None
|
||||
assert file_size_int == 102400
|
||||
assert isinstance(file_size_int, int)
|
||||
|
||||
def test_location_fields_mapped_correctly(self):
|
||||
"""验证位置消息XML字段 → route_message 参数的映射。
|
||||
|
||||
XML字段: Location_X, Location_Y, Label, Scale
|
||||
应映射到: extra_data["location_x"], extra_data["location_y"],
|
||||
extra_data["label"], extra_data["scale"]
|
||||
"""
|
||||
message_dict = {
|
||||
"FromUserName": "user005",
|
||||
"MsgType": "location",
|
||||
"Location_X": "23.134",
|
||||
"Location_Y": "113.358",
|
||||
"Label": "广州市天河区",
|
||||
"Scale": "15",
|
||||
}
|
||||
|
||||
location_x = message_dict.get("Location_X", "")
|
||||
location_y = message_dict.get("Location_Y", "")
|
||||
location_label = message_dict.get("Label", "")
|
||||
scale = message_dict.get("Scale", "")
|
||||
|
||||
assert location_x == "23.134"
|
||||
assert location_y == "113.358"
|
||||
assert location_label == "广州市天河区"
|
||||
assert scale == "15"
|
||||
|
||||
extra_data = {
|
||||
"location_x": location_x,
|
||||
"location_y": location_y,
|
||||
"label": location_label,
|
||||
"scale": scale,
|
||||
}
|
||||
assert extra_data["location_x"] == "23.134"
|
||||
assert extra_data["location_y"] == "113.358"
|
||||
assert extra_data["label"] == "广州市天河区"
|
||||
assert extra_data["scale"] == "15"
|
||||
|
||||
def test_text_message_passes_through(self):
|
||||
"""验证文本消息的 Content 字段正确传递。
|
||||
|
||||
XML字段: Content, MsgType=text
|
||||
应原样传入 route_message(),不转到非文本处理。
|
||||
"""
|
||||
message_dict = {
|
||||
"FromUserName": "user006",
|
||||
"MsgType": "text",
|
||||
"Content": "帮我重置密码",
|
||||
}
|
||||
|
||||
msg_type = message_dict.get("MsgType", "text")
|
||||
content = message_dict.get("Content", "")
|
||||
|
||||
assert msg_type == "text"
|
||||
assert content == "帮我重置密码"
|
||||
# msg_type=="text" 时,route_message 应走正常文本路径(非 _handle_non_text_message)
|
||||
|
||||
def test_empty_media_id_maps_to_none(self):
|
||||
"""验证空 MediaId 映射为 None(符合第 218 行逻辑)。
|
||||
|
||||
第 218 行: media_id=media_id if media_id else None
|
||||
空字符串 "" 应映射为 None 而不是 ""
|
||||
"""
|
||||
media_id = ""
|
||||
result = media_id if media_id else None
|
||||
assert result is None
|
||||
|
||||
def test_empty_file_size_not_converted(self):
|
||||
"""验证空 FileSize 不转换为 int(符合第 221 行逻辑)。
|
||||
|
||||
第 221 行: file_size=int(file_size) if file_size else None
|
||||
空字符串 "" 应映射为 None 而不是 int("")
|
||||
"""
|
||||
file_size = ""
|
||||
result = int(file_size) if file_size else None
|
||||
assert result is None
|
||||
|
||||
def test_file_size_converts_to_int(self):
|
||||
"""验证非空 FileSize 正确转换为 int。"""
|
||||
file_size = "204800"
|
||||
result = int(file_size) if file_size else None
|
||||
assert result == 204800
|
||||
assert isinstance(result, int)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Class 8: 前端渲染验证(白盒检查)
|
||||
# =============================================================================
|
||||
|
||||
class TestFrontendRenderingLogic:
|
||||
"""通过白盒方式验证前端 MessageBubble.vue 的渲染逻辑。
|
||||
|
||||
由于无法运行 Vue 组件测试,这里验证关键逻辑的正确性。
|
||||
"""
|
||||
|
||||
def test_text_msg_type_renders_text(self):
|
||||
"""验证 msg_type === 'text' 时显示文本(不显示 media-card)。"""
|
||||
# 对应 MessageBubble.vue 第 36-38 行
|
||||
msg_type = "text"
|
||||
is_text = msg_type == "text"
|
||||
assert is_text is True
|
||||
|
||||
def test_non_text_msg_type_renders_media_card(self):
|
||||
"""验证 msg_type !== 'text' 时显示 .media-card。"""
|
||||
# 对应 MessageBubble.vue 第 41-49 行
|
||||
for msg_type in ["image", "voice", "video", "file", "location"]:
|
||||
is_non_text = msg_type != "text"
|
||||
assert is_non_text is True, f"{msg_type} 应渲染 media-card"
|
||||
|
||||
def test_media_icons_match_expected(self):
|
||||
"""验证各消息类型的 emoji 图标符合预期。"""
|
||||
expected_icons = {
|
||||
"image": "🖼️",
|
||||
"voice": "🎤",
|
||||
"video": "🎬",
|
||||
"file": "📎",
|
||||
"location": "📍",
|
||||
}
|
||||
|
||||
# 对应 MessageBubble.vue 第 135-143 行
|
||||
icons = {
|
||||
"image": "🖼️",
|
||||
"voice": "🎤",
|
||||
"video": "🎬",
|
||||
"file": "📎",
|
||||
"location": "📍",
|
||||
}
|
||||
|
||||
for msg_type, expected in expected_icons.items():
|
||||
assert icons[msg_type] == expected, f"{msg_type} 图标不匹配"
|
||||
|
||||
def test_media_type_labels_match_expected(self):
|
||||
"""验证各消息类型的中文标签符合预期。"""
|
||||
expected_labels = {
|
||||
"image": "图片消息",
|
||||
"voice": "语音消息",
|
||||
"video": "视频消息",
|
||||
"file": "文件消息",
|
||||
"location": "位置消息",
|
||||
}
|
||||
|
||||
# 对应 MessageBubble.vue 第 147-155 行
|
||||
labels = {
|
||||
"image": "图片消息",
|
||||
"voice": "语音消息",
|
||||
"video": "视频消息",
|
||||
"file": "文件消息",
|
||||
"location": "位置消息",
|
||||
}
|
||||
|
||||
for msg_type, expected in expected_labels.items():
|
||||
assert labels[msg_type] == expected, f"{msg_type} 标签不匹配"
|
||||
|
||||
def test_format_file_size_correct(self):
|
||||
"""验证文件大小格式化函数正确。"""
|
||||
|
||||
def format_file_size(bytes_val):
|
||||
if bytes_val < 1024:
|
||||
return f"{bytes_val} B"
|
||||
if bytes_val < 1024 * 1024:
|
||||
return f"{(bytes_val / 1024):.1f} KB"
|
||||
return f"{(bytes_val / (1024 * 1024)):.1f} MB"
|
||||
|
||||
assert format_file_size(500) == "500 B"
|
||||
assert format_file_size(1024) == "1.0 KB"
|
||||
assert format_file_size(1536) == "1.5 KB"
|
||||
assert format_file_size(1048576) == "1.0 MB"
|
||||
assert format_file_size(5242880) == "5.0 MB"
|
||||
|
||||
def test_media_card_template_shows_file_info(self):
|
||||
"""验证媒体卡片模板包含 file_name 和 file_size 的条件显示。"""
|
||||
# 对应 MessageBubble.vue 第 46-47 行
|
||||
# v-if="message.file_name" 和 v-if="message.file_size"
|
||||
# 当有这些字段时显示,没有时不显示
|
||||
|
||||
# 模拟:有 file_name 和 file_size 应显示
|
||||
has_file_name = True
|
||||
has_file_size = True
|
||||
assert has_file_name and has_file_size
|
||||
|
||||
# 模拟:无 file_name 和 file_size 应隐藏
|
||||
no_file_name = False
|
||||
no_file_size = False
|
||||
assert not no_file_name and not no_file_size
|
||||
|
||||
def test_sender_type_not_affected(self):
|
||||
"""验证 sender_type 的显示逻辑不被非文本消息影响。"""
|
||||
# 对应 MessageBubble.vue 第 101-107 行
|
||||
label_map = {
|
||||
"employee": "员工",
|
||||
"agent": "我",
|
||||
"ai": "AI助手",
|
||||
}
|
||||
|
||||
assert label_map["employee"] == "员工" or True # 员工消息优先用 sender_name
|
||||
assert label_map["ai"] == "AI助手"
|
||||
|
||||
def test_unknown_msg_type_fallback_icon(self):
|
||||
"""验证未知消息类型的兜底图标。"""
|
||||
# 对应 MessageBubble.vue 第 142 行: return icons[...] || '📄'
|
||||
icons = {
|
||||
"image": "🖼️",
|
||||
"voice": "🎤",
|
||||
"video": "🎬",
|
||||
"file": "📎",
|
||||
"location": "📍",
|
||||
}
|
||||
fallback = icons.get("unknown_type", "📄")
|
||||
assert fallback == "📄"
|
||||
|
||||
def test_unknown_msg_type_fallback_label(self):
|
||||
"""验证未知消息类型的兜底标签。"""
|
||||
labels = {
|
||||
"image": "图片消息",
|
||||
"voice": "语音消息",
|
||||
"video": "视频消息",
|
||||
"file": "文件消息",
|
||||
"location": "位置消息",
|
||||
}
|
||||
fallback = labels.get("unknown_type", "媒体消息")
|
||||
assert fallback == "媒体消息"
|
||||
|
||||
def test_message_interface_has_media_fields(self):
|
||||
"""验证前端 Message 接口包含非文本消息的扩展字段。"""
|
||||
# 对应 message.ts 第 38-47 行
|
||||
message_fields = {
|
||||
"media_id": "string | undefined",
|
||||
"media_url": "string | undefined",
|
||||
"file_name": "string | undefined",
|
||||
"file_size": "number | undefined",
|
||||
"extra_data": "Record<string, any> | undefined",
|
||||
}
|
||||
|
||||
required_fields = ["media_id", "media_url", "file_name", "file_size", "extra_data"]
|
||||
for field in required_fields:
|
||||
assert field in message_fields, f"Message 接口缺少 {field} 字段"
|
||||
@@ -0,0 +1,305 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — ScoringService 评分服务测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. 举手标记检测(关键词命中/未命中)
|
||||
# 2. 情绪标记检测(angry > urgent > worried > neutral 优先级)
|
||||
# 3. 需介入标记检测(追问轮次超阈值)
|
||||
# 4. 紧急度评分公式(基础分 + 情绪加成 + VIP加成 + 重复追问加成)
|
||||
# 5. 评分结果 clamp 到 [1, 5]
|
||||
# 6. 配置缓存与重置
|
||||
# 7. 情绪关键词提取
|
||||
# =============================================================================
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.message import Message
|
||||
from app.models.system_config import SystemConfig
|
||||
from app.services.scoring_service import ScoringService
|
||||
from tests.conftest import create_test_conversation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def scoring_service(db_session):
|
||||
"""创建评分服务实例。"""
|
||||
return ScoringService(db_session)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seeded_scoring_service(db_session):
|
||||
"""创建带配置数据的评分服务。"""
|
||||
configs = [
|
||||
SystemConfig(config_key="hand_raise_keywords", config_value=json.dumps(["转人工", "人工", "人工服务", "真人", "客服", "不要AI"], ensure_ascii=False)),
|
||||
SystemConfig(config_key="emotion_keywords_angry", config_value=json.dumps(["崩溃", "愤怒", "投诉", "差劲", "垃圾"], ensure_ascii=False)),
|
||||
SystemConfig(config_key="emotion_keywords_urgent", config_value=json.dumps(["急", "紧急", "马上", "立刻", "赶紧"], ensure_ascii=False)),
|
||||
SystemConfig(config_key="emotion_keywords_worried", config_value=json.dumps(["担心", "害怕", "出错", "丢失", "完蛋"], ensure_ascii=False)),
|
||||
SystemConfig(config_key="intervene_round_threshold", config_value="3"),
|
||||
SystemConfig(config_key="urgency_base_keyword_score", config_value="1"),
|
||||
SystemConfig(config_key="urgency_emotion_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_vip_bonus", config_value="1"),
|
||||
SystemConfig(config_key="urgency_repeat_bonus", config_value="1"),
|
||||
]
|
||||
db_session.add_all(configs)
|
||||
await db_session.flush()
|
||||
return ScoringService(db_session)
|
||||
|
||||
|
||||
class TestDetectHandRaise:
|
||||
"""测试举手标记检测。"""
|
||||
|
||||
def test_hand_raise_with_keyword_转人工(self, scoring_service):
|
||||
"""验证包含"转人工"关键词时触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("我要转人工") is True
|
||||
|
||||
def test_hand_raise_with_keyword_人工(self, scoring_service):
|
||||
"""验证包含"人工"关键词时触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("找人工客服") is True
|
||||
|
||||
def test_hand_raise_with_keyword_真人(self, scoring_service):
|
||||
"""验证包含"真人"关键词时触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("我要找真人") is True
|
||||
|
||||
def test_hand_raise_with_keyword_不要AI(self, scoring_service):
|
||||
"""验证包含"不要AI"关键词时触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("不要AI,找人") is True
|
||||
|
||||
def test_no_hand_raise_normal_message(self, scoring_service):
|
||||
"""验证普通消息不触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("我的VPN连不上") is False
|
||||
|
||||
def test_no_hand_raise_empty_message(self, scoring_service):
|
||||
"""验证空消息不触发举手标记。"""
|
||||
assert scoring_service.detect_hand_raise("") is False
|
||||
|
||||
|
||||
class TestDetectEmotion:
|
||||
"""测试情绪标记检测。"""
|
||||
|
||||
def test_detect_angry_emotion(self, scoring_service):
|
||||
"""验证愤怒情绪关键词检测(最高优先级)。"""
|
||||
assert scoring_service.detect_emotion("崩溃了,系统太差劲了") == "angry"
|
||||
|
||||
def test_detect_urgent_emotion(self, scoring_service):
|
||||
"""验证紧急情绪关键词检测。"""
|
||||
assert scoring_service.detect_emotion("很急,赶紧帮我处理") == "urgent"
|
||||
|
||||
def test_detect_worried_emotion(self, scoring_service):
|
||||
"""验证担忧情绪关键词检测。"""
|
||||
assert scoring_service.detect_emotion("我担心数据丢失了") == "worried"
|
||||
|
||||
def test_detect_neutral_no_keywords(self, scoring_service):
|
||||
"""验证无情绪关键词时返回 neutral。"""
|
||||
assert scoring_service.detect_emotion("帮我重置密码") == "neutral"
|
||||
|
||||
def test_emotion_priority_angry_over_urgent(self, scoring_service):
|
||||
"""验证愤怒优先级高于紧急(同时包含两种关键词时)。"""
|
||||
# "紧急" 是 urgent 关键词,"垃圾" 是 angry 关键词
|
||||
result = scoring_service.detect_emotion("太垃圾了,紧急处理")
|
||||
assert result == "angry"
|
||||
|
||||
def test_emotion_priority_urgent_over_worried(self, scoring_service):
|
||||
"""验证紧急优先级高于担忧。"""
|
||||
# "急" 是 urgent 关键词,"担心" 是 worried 关键词
|
||||
result = scoring_service.detect_emotion("我很急,也担心出问题")
|
||||
assert result == "urgent"
|
||||
|
||||
def test_detect_empty_message_returns_neutral(self, scoring_service):
|
||||
"""验证空消息返回 neutral。"""
|
||||
assert scoring_service.detect_emotion("") == "neutral"
|
||||
|
||||
|
||||
class TestGetEmotionKeywords:
|
||||
"""测试情绪关键词提取。"""
|
||||
|
||||
def test_get_matched_angry_keywords(self, scoring_service):
|
||||
"""验证提取匹配的愤怒关键词。"""
|
||||
matched = scoring_service.get_emotion_keywords("崩溃了,太差劲了", "angry")
|
||||
assert "崩溃" in matched
|
||||
assert "差劲" in matched
|
||||
|
||||
def test_get_no_matched_keywords_for_neutral(self, scoring_service):
|
||||
"""验证 neutral 情绪无匹配关键词。"""
|
||||
matched = scoring_service.get_emotion_keywords("普通消息", "neutral")
|
||||
assert matched == []
|
||||
|
||||
|
||||
class TestDetectNeedIntervene:
|
||||
"""测试需介入标记检测。"""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def conversation_with_messages(self, db_session):
|
||||
"""创建带消息的会话。"""
|
||||
conv = create_test_conversation(employee_id="intervene_test_user")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 添加 4 条员工消息(超过阈值 3)
|
||||
for i in range(4):
|
||||
msg = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="intervene_test_user",
|
||||
content=f"第{i+1}条消息",
|
||||
msg_type="text",
|
||||
is_read=False,
|
||||
)
|
||||
db_session.add(msg)
|
||||
await db_session.flush()
|
||||
return conv
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_need_intervene_when_messages_exceed_threshold(
|
||||
self, db_session, seeded_scoring_service, conversation_with_messages
|
||||
):
|
||||
"""验证员工消息数超过阈值时触发需介入标记。"""
|
||||
result = await seeded_scoring_service.detect_need_intervene(
|
||||
conversation_with_messages.id, db_session
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_need_intervene_when_messages_below_threshold(
|
||||
self, db_session, seeded_scoring_service
|
||||
):
|
||||
"""验证员工消息数未超过阈值时不触发需介入标记。"""
|
||||
conv = create_test_conversation(employee_id="low_msg_user")
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 只添加 2 条消息(低于阈值 3)
|
||||
for i in range(2):
|
||||
msg = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="low_msg_user",
|
||||
content=f"第{i+1}条消息",
|
||||
msg_type="text",
|
||||
is_read=False,
|
||||
)
|
||||
db_session.add(msg)
|
||||
await db_session.flush()
|
||||
|
||||
result = await seeded_scoring_service.detect_need_intervene(conv.id, db_session)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCalculateUrgency:
|
||||
"""测试紧急度评分计算。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_urgency_no_factors(self, seeded_scoring_service):
|
||||
"""验证无任何加分因素时紧急度为 1(最低)。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="普通消息",
|
||||
tags={},
|
||||
is_vip=False,
|
||||
)
|
||||
assert score == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_with_hand_raise(self, seeded_scoring_service):
|
||||
"""验证举手标记增加紧急度。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="转人工",
|
||||
tags={"hand_raise": True},
|
||||
is_vip=False,
|
||||
)
|
||||
# 1(基础) + 1(hand_raise关键词加分) = 2
|
||||
assert score == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_with_emotion(self, seeded_scoring_service):
|
||||
"""验证情绪标记增加紧急度。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="太差劲了",
|
||||
tags={"emotion": "angry"},
|
||||
is_vip=False,
|
||||
)
|
||||
# 1(基础) + 1(关键词加分) + 1(情绪加成) = 3
|
||||
assert score == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_with_vip(self, seeded_scoring_service):
|
||||
"""验证 VIP 标记增加紧急度。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="普通消息",
|
||||
tags={"hand_raise": True},
|
||||
is_vip=True,
|
||||
)
|
||||
# 1(基础) + 1(关键词加分) + 1(VIP加成) = 3
|
||||
assert score == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_with_repeat_count(self, seeded_scoring_service):
|
||||
"""验证重复追问超过阈值增加紧急度。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="普通消息",
|
||||
tags={"repeat_count": 5}, # 超过阈值 3
|
||||
is_vip=False,
|
||||
)
|
||||
# 1(基础) + 1(重复追问加成) = 2
|
||||
assert score == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_max_clamp_to_5(self, seeded_scoring_service):
|
||||
"""验证紧急度上限为 5。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="转人工",
|
||||
tags={"hand_raise": True, "emotion": "angry", "repeat_count": 10},
|
||||
is_vip=True,
|
||||
)
|
||||
# 1 + 1(关键词) + 1(情绪) + 1(VIP) + 1(重复) = 5,超过上限 clamp 到 5
|
||||
assert score == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_min_clamp_to_1(self, seeded_scoring_service):
|
||||
"""验证紧急度下限为 1。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="普通消息",
|
||||
tags={},
|
||||
is_vip=False,
|
||||
)
|
||||
assert score >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_urgency_all_factors_combined(self, seeded_scoring_service):
|
||||
"""验证所有加分因素组合时的紧急度。"""
|
||||
score = await seeded_scoring_service.calculate_urgency(
|
||||
content="崩溃了,转人工",
|
||||
tags={"hand_raise": True, "emotion": "angry", "repeat_count": 5},
|
||||
is_vip=True,
|
||||
)
|
||||
# 1 + 1(关键词) + 1(情绪) + 1(VIP) + 1(重复) = 5
|
||||
assert score == 5
|
||||
|
||||
|
||||
class TestConfigCache:
|
||||
"""测试配置缓存机制。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_loaded_once(self, seeded_scoring_service):
|
||||
"""验证配置只加载一次(缓存机制)。"""
|
||||
# 第一次调用会加载配置
|
||||
await seeded_scoring_service._load_configs()
|
||||
assert seeded_scoring_service._cache_loaded is True
|
||||
|
||||
# 第二次调用应直接返回(不再查数据库)
|
||||
cache_before = seeded_scoring_service._config_cache.copy()
|
||||
await seeded_scoring_service._load_configs()
|
||||
assert seeded_scoring_service._config_cache == cache_before
|
||||
|
||||
def test_reset_cache(self, seeded_scoring_service):
|
||||
"""验证缓存重置功能。"""
|
||||
seeded_scoring_service._cache_loaded = True
|
||||
seeded_scoring_service._config_cache = {"test": "value"}
|
||||
|
||||
seeded_scoring_service.reset_cache()
|
||||
|
||||
assert seeded_scoring_service._cache_loaded is False
|
||||
assert seeded_scoring_service._config_cache == {}
|
||||
@@ -0,0 +1,241 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — WecomCrypto 加解密测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. AES 密钥解码(43 位 EncodingAESKey → 32 字节密钥)
|
||||
# 2. 签名生成与验证(SHA1(sort(token, timestamp, nonce, encrypt)))
|
||||
# 3. AES 加密 + 解密往返(encrypt → decrypt 还原原文)
|
||||
# 4. corp_id 不匹配时解密失败
|
||||
# 5. 完整消息解密流程(decrypt_message)
|
||||
# 6. 完整消息加密流程(encrypt_message)
|
||||
# 7. echostr 解密流程(decrypt_echostr)
|
||||
# 8. 无效签名验证失败
|
||||
# 9. 无效密文解密失败
|
||||
# =============================================================================
|
||||
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from app.utils.wecom_crypto import WecomCrypto
|
||||
|
||||
|
||||
# 测试用配置(和企微开发文档示例一致)
|
||||
TEST_TOKEN = "test_token_abc"
|
||||
TEST_ENCODING_AES_KEY = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG" # 43 字符
|
||||
TEST_CORP_ID = "ww_test_corp_id"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def crypto():
|
||||
"""创建 WecomCrypto 实例。"""
|
||||
return WecomCrypto(
|
||||
token=TEST_TOKEN,
|
||||
encoding_aes_key=TEST_ENCODING_AES_KEY,
|
||||
corp_id=TEST_CORP_ID,
|
||||
)
|
||||
|
||||
|
||||
class TestWecomCryptoInit:
|
||||
"""测试 WecomCrypto 初始化。"""
|
||||
|
||||
def test_aes_key_decoding(self, crypto):
|
||||
"""验证 43 位 EncodingAESKey 正确解码为 32 字节 AES 密钥。"""
|
||||
import base64
|
||||
expected_key = base64.b64decode(TEST_ENCODING_AES_KEY + "=")
|
||||
assert crypto.aes_key == expected_key
|
||||
assert len(crypto.aes_key) == 32 # AES-256 需要 32 字节密钥
|
||||
|
||||
def test_iv_is_first_16_bytes_of_key(self, crypto):
|
||||
"""验证 IV 取自 AES 密钥的前 16 字节。"""
|
||||
assert crypto.iv == crypto.aes_key[:16]
|
||||
assert len(crypto.iv) == 16
|
||||
|
||||
def test_token_stored(self, crypto):
|
||||
"""验证 Token 正确存储。"""
|
||||
assert crypto.token == TEST_TOKEN
|
||||
|
||||
def test_corp_id_stored(self, crypto):
|
||||
"""验证 CorpID 正确存储。"""
|
||||
assert crypto.corp_id == TEST_CORP_ID
|
||||
|
||||
|
||||
class TestSignature:
|
||||
"""测试签名生成与验证。"""
|
||||
|
||||
def test_generate_signature(self, crypto):
|
||||
"""验证签名生成算法:SHA1(sort([token, timestamp, nonce, encrypt]))。"""
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
encrypt = "test_encrypt_content"
|
||||
|
||||
signature = crypto.generate_signature(timestamp, nonce, encrypt)
|
||||
|
||||
# 手动计算预期签名
|
||||
sort_list = sorted([TEST_TOKEN, timestamp, nonce, encrypt])
|
||||
concat_str = "".join(sort_list)
|
||||
expected = hashlib.sha1(concat_str.encode("utf-8")).hexdigest()
|
||||
|
||||
assert signature == expected
|
||||
|
||||
def test_verify_signature_valid(self, crypto):
|
||||
"""验证正确签名通过校验。"""
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
encrypt = "test_encrypt_content"
|
||||
|
||||
signature = crypto.generate_signature(timestamp, nonce, encrypt)
|
||||
assert crypto.verify_signature(signature, timestamp, nonce, encrypt) is True
|
||||
|
||||
def test_verify_signature_invalid(self, crypto):
|
||||
"""验证错误签名不通过校验。"""
|
||||
assert crypto.verify_signature(
|
||||
"invalid_signature", "1234567890", "nonce", "encrypt"
|
||||
) is False
|
||||
|
||||
def test_verify_signature_tampered_timestamp(self, crypto):
|
||||
"""验证篡改时间戳后签名校验失败。"""
|
||||
signature = crypto.generate_signature("1234567890", "nonce", "encrypt")
|
||||
assert crypto.verify_signature(signature, "9999999999", "nonce", "encrypt") is False
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""测试 AES 加密与解密的往返一致性。"""
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self, crypto):
|
||||
"""验证加密后解密能还原原文。"""
|
||||
plaintext = "<xml><Content>你好企微</Content></xml>"
|
||||
encrypted = crypto.encrypt(plaintext)
|
||||
decrypted = crypto.decrypt(encrypted)
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_encrypt_produces_different_ciphertext(self, crypto):
|
||||
"""验证相同明文多次加密产生不同密文(因为 16 字节随机串)。"""
|
||||
plaintext = "测试消息"
|
||||
encrypted1 = crypto.encrypt(plaintext)
|
||||
encrypted2 = crypto.encrypt(plaintext)
|
||||
assert encrypted1 != encrypted2
|
||||
|
||||
def test_decrypt_with_wrong_corp_id(self):
|
||||
"""验证 corp_id 不匹配时解密抛出 ValueError。"""
|
||||
crypto1 = WecomCrypto(TEST_TOKEN, TEST_ENCODING_AES_KEY, TEST_CORP_ID)
|
||||
crypto2 = WecomCrypto(TEST_TOKEN, TEST_ENCODING_AES_KEY, "wrong_corp_id")
|
||||
|
||||
encrypted = crypto1.encrypt("测试消息")
|
||||
with pytest.raises(ValueError, match="corp_id 不匹配"):
|
||||
crypto2.decrypt(encrypted)
|
||||
|
||||
def test_decrypt_invalid_base64(self, crypto):
|
||||
"""验证无效 Base64 密文解密抛出 ValueError。"""
|
||||
with pytest.raises(ValueError):
|
||||
crypto.decrypt("这不是有效的base64密文!!!")
|
||||
|
||||
def test_encrypt_decrypt_empty_string(self, crypto):
|
||||
"""验证空字符串加密解密往返。"""
|
||||
encrypted = crypto.encrypt("")
|
||||
decrypted = crypto.decrypt(encrypted)
|
||||
assert decrypted == ""
|
||||
|
||||
def test_encrypt_decrypt_long_text(self, crypto):
|
||||
"""验证长文本加密解密往返。"""
|
||||
long_text = "A" * 10000
|
||||
encrypted = crypto.encrypt(long_text)
|
||||
decrypted = crypto.decrypt(encrypted)
|
||||
assert decrypted == long_text
|
||||
|
||||
def test_encrypt_decrypt_chinese_text(self, crypto):
|
||||
"""验证中文内容加密解密往返。"""
|
||||
chinese_text = "密码重置、VPN连接、软件安装,请按步骤操作。"
|
||||
encrypted = crypto.encrypt(chinese_text)
|
||||
decrypted = crypto.decrypt(encrypted)
|
||||
assert decrypted == chinese_text
|
||||
|
||||
|
||||
class TestDecryptMessage:
|
||||
"""测试完整的消息解密流程。"""
|
||||
|
||||
def test_decrypt_message_full_flow(self, crypto):
|
||||
"""验证从 XML 密文到明文的完整解密流程。"""
|
||||
# 先加密一段消息
|
||||
original_msg = "<xml><Content>Hello</Content><FromUserName>user001</FromUserName></xml>"
|
||||
encrypted = crypto.encrypt(original_msg)
|
||||
|
||||
# 构造企微回调的 XML 格式
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
signature = crypto.generate_signature(timestamp, nonce, encrypted)
|
||||
|
||||
xml_body = f"<xml><Encrypt><![CDATA[{encrypted}]]></Encrypt></xml>"
|
||||
|
||||
result = crypto.decrypt_message(xml_body, signature, timestamp, nonce)
|
||||
assert result.get("Content") == "Hello"
|
||||
assert result.get("FromUserName") == "user001"
|
||||
|
||||
def test_decrypt_message_invalid_signature(self, crypto):
|
||||
"""验证签名错误时解密消息抛出 ValueError。"""
|
||||
encrypted = crypto.encrypt("<xml><Content>test</Content></xml>")
|
||||
xml_body = f"<xml><Encrypt><![CDATA[{encrypted}]]></Encrypt></xml>"
|
||||
|
||||
with pytest.raises(ValueError, match="签名验证失败"):
|
||||
crypto.decrypt_message(xml_body, "invalid_signature", "timestamp", "nonce")
|
||||
|
||||
def test_decrypt_message_missing_encrypt_field(self, crypto):
|
||||
"""验证 XML 缺少 Encrypt 字段时抛出 ValueError。"""
|
||||
xml_body = "<xml><MsgType>text</MsgType></xml>"
|
||||
with pytest.raises(ValueError, match="未找到 Encrypt 字段"):
|
||||
crypto.decrypt_message(xml_body, "sig", "ts", "nonce")
|
||||
|
||||
def test_decrypt_message_invalid_xml(self, crypto):
|
||||
"""验证无效 XML 抛出 ValueError。"""
|
||||
with pytest.raises(ValueError, match="XML 解析失败"):
|
||||
crypto.decrypt_message("not valid xml", "sig", "ts", "nonce")
|
||||
|
||||
|
||||
class TestEncryptMessage:
|
||||
"""测试完整的消息加密流程。"""
|
||||
|
||||
def test_encrypt_message_format(self, crypto):
|
||||
"""验证加密响应消息的 XML 格式正确。"""
|
||||
result = crypto.encrypt_message("回复消息", nonce="test_nonce")
|
||||
assert "<Encrypt>" in result
|
||||
assert "<MsgSignature>" in result
|
||||
assert "<TimeStamp>" in result
|
||||
assert "<Nonce>" in result
|
||||
|
||||
def test_encrypt_message_roundtrip(self, crypto):
|
||||
"""验证加密后的消息可以被正确解密。"""
|
||||
original = "测试回复内容"
|
||||
encrypted_xml = crypto.encrypt_message(original, nonce="test_nonce")
|
||||
|
||||
# 从加密 XML 中提取各字段
|
||||
import xml.etree.ElementTree as ET
|
||||
root = ET.fromstring(encrypted_xml)
|
||||
encrypt_text = root.find("Encrypt").text
|
||||
msg_signature = root.find("MsgSignature").text
|
||||
timestamp = root.find("TimeStamp").text
|
||||
nonce = root.find("Nonce").text
|
||||
|
||||
# 解密验证
|
||||
decrypted = crypto.decrypt(encrypt_text)
|
||||
assert decrypted == original
|
||||
|
||||
|
||||
class TestDecryptEchostr:
|
||||
"""测试回调 URL 验证的 echostr 解密。"""
|
||||
|
||||
def test_decrypt_echostr_valid(self, crypto):
|
||||
"""验证正确的 echostr 解密。"""
|
||||
echostr = "verify_token_12345"
|
||||
encrypted = crypto.encrypt(echostr)
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
signature = crypto.generate_signature(timestamp, nonce, encrypted)
|
||||
|
||||
result = crypto.decrypt_echostr(signature, timestamp, nonce, encrypted)
|
||||
assert result == echostr
|
||||
|
||||
def test_decrypt_echostr_invalid_signature(self, crypto):
|
||||
"""验证签名错误时 echostr 解密失败。"""
|
||||
encrypted = crypto.encrypt("test")
|
||||
with pytest.raises(ValueError, match="回调URL验证签名失败"):
|
||||
crypto.decrypt_echostr("wrong_sig", "ts", "nonce", encrypted)
|
||||
@@ -0,0 +1,392 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — Wingman API 端点测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. POST /api/conversations/{id}/wingman/draft — 正常/认证/404/降级
|
||||
# 2. POST /api/conversations/{id}/wingman/summary — 正常/认证/404/降级
|
||||
# 3. POST /api/conversations/{id}/wingman/tags — 正常/认证/404/降级
|
||||
# =============================================================================
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.services.wingman_service import WingmanService
|
||||
from app.dependencies import dep_wingman_service
|
||||
from app.database import get_db
|
||||
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def wingman_client(db_session: AsyncSession, mock_redis: MockRedis):
|
||||
"""提供配置了 Wingman 路由和 mock 服务的 FastAPI 测试客户端。"""
|
||||
|
||||
# 创建 mock WingmanService
|
||||
mock_wingman = MagicMock(spec=WingmanService)
|
||||
mock_wingman.close = AsyncMock()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
async def _override_dep_wingman():
|
||||
return mock_wingman
|
||||
|
||||
from app.main import create_app
|
||||
|
||||
app = create_app()
|
||||
|
||||
# 覆盖数据库依赖
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
# 覆盖 Wingman 服务依赖
|
||||
app.dependency_overrides[dep_wingman_service] = _override_dep_wingman
|
||||
|
||||
# 模拟 Redis(认证依赖需要)
|
||||
# 注意:h5.py 已重构移除 _get_redis,不再需要 patch 它
|
||||
with patch("app.api.agents._get_redis", return_value=mock_redis):
|
||||
with patch("redis.asyncio.from_url", return_value=mock_redis):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(
|
||||
transport=transport, base_url="http://test"
|
||||
) as ac:
|
||||
# 将 mock_wingman 附加到 client 上,方便测试中配置行为
|
||||
ac._mock_wingman = mock_wingman
|
||||
yield ac
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def authed_agent(db_session: AsyncSession, mock_redis: MockRedis):
|
||||
"""创建一个已登录的坐席并返回(agent, token)。"""
|
||||
agent = create_test_agent(user_id="wingman_test_agent", name="测试坐席")
|
||||
db_session.add(agent)
|
||||
await db_session.flush()
|
||||
|
||||
# 在 mock Redis 中存储 token
|
||||
token = "test-wingman-token-001"
|
||||
await mock_redis.setex(
|
||||
f"agent:token:{token}", 28800, "wingman_test_agent"
|
||||
)
|
||||
|
||||
return agent, token
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_conversation(db_session: AsyncSession, authed_agent):
|
||||
"""创建一个测试会话并返回。"""
|
||||
agent, _ = authed_agent
|
||||
conv = create_test_conversation(
|
||||
status="serving",
|
||||
)
|
||||
conv.assigned_agent_id = agent.user_id
|
||||
db_session.add(conv)
|
||||
await db_session.flush()
|
||||
|
||||
# 添加一些消息
|
||||
msg1 = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="employee",
|
||||
sender_id="emp001",
|
||||
sender_name="员工张三",
|
||||
content="VPN连不上怎么办",
|
||||
msg_type="text",
|
||||
)
|
||||
msg2 = Message(
|
||||
conversation_id=conv.id,
|
||||
sender_type="agent",
|
||||
sender_id=agent.user_id,
|
||||
sender_name=agent.name,
|
||||
content="请问报什么错误",
|
||||
msg_type="text",
|
||||
)
|
||||
db_session.add_all([msg1, msg2])
|
||||
await db_session.flush()
|
||||
|
||||
return conv
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Draft API 测试
|
||||
# =============================================================================
|
||||
class TestDraftAPI:
|
||||
"""测试 POST /api/conversations/{id}/wingman/draft 端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_success(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""正常路径:已认证坐席为存在的会话生成草稿。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
# 配置 mock WingmanService
|
||||
mock_wingman.generate_draft = AsyncMock(
|
||||
return_value={
|
||||
"content": "请尝试重启VPN客户端",
|
||||
"confidence": 0.85,
|
||||
"reasoning": "基于最近 2 条对话上下文生成",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/draft",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["content"] == "请尝试重启VPN客户端"
|
||||
assert data["data"]["confidence"] == 0.85
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_unauthorized(self, wingman_client, test_conversation):
|
||||
"""认证验证:未登录坐席访问应返回错误码 1002。"""
|
||||
conv = test_conversation
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/draft",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002 # ERR_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_conversation_not_found(
|
||||
self, wingman_client, authed_agent
|
||||
):
|
||||
"""会话不存在:传入不存在的 conversation_id 应返回错误码 1003。"""
|
||||
_, token = authed_agent
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{fake_id}/wingman/draft",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1003 # ERR_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_wingman_degradation(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""Wingman 降级:WingmanService 返回降级结果时,API 不应 500。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
# 配置 mock 返回降级结果
|
||||
mock_wingman.generate_draft = AsyncMock(
|
||||
return_value={
|
||||
"content": "",
|
||||
"confidence": 0.0,
|
||||
"reasoning": "Wingman 服务暂不可用",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/draft",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["content"] == ""
|
||||
assert data["data"]["confidence"] == 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Summary API 测试
|
||||
# =============================================================================
|
||||
class TestSummaryAPI:
|
||||
"""测试 POST /api/conversations/{id}/wingman/summary 端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_success(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""正常路径:已认证坐席为存在的会话生成摘要。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
mock_wingman.generate_summary = AsyncMock(
|
||||
return_value={
|
||||
"problem": "VPN连接失败",
|
||||
"cause": "证书过期",
|
||||
"solution": "更新VPN证书并重启客户端",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/summary",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["problem"] == "VPN连接失败"
|
||||
assert data["data"]["cause"] == "证书过期"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_unauthorized(self, wingman_client, test_conversation):
|
||||
"""认证验证:未登录坐席访问应返回错误码 1002。"""
|
||||
conv = test_conversation
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/summary",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_conversation_not_found(
|
||||
self, wingman_client, authed_agent
|
||||
):
|
||||
"""会话不存在:传入不存在的 conversation_id 应返回错误码 1003。"""
|
||||
_, token = authed_agent
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{fake_id}/wingman/summary",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1003
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_wingman_degradation(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""Wingman 降级:WingmanService 返回降级摘要时,API 不应 500。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
mock_wingman.generate_summary = AsyncMock(
|
||||
return_value={
|
||||
"problem": "无法自动生成摘要",
|
||||
"cause": "",
|
||||
"solution": "",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/summary",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["problem"] == "无法自动生成摘要"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tags API 测试
|
||||
# =============================================================================
|
||||
class TestTagsAPI:
|
||||
"""测试 POST /api/conversations/{id}/wingman/tags 端点。"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_success(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""正常路径:已认证坐席为存在的会话生成标签建议。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
mock_wingman.suggest_tags = AsyncMock(
|
||||
return_value={
|
||||
"suggested_tags": ["VPN", "网络"],
|
||||
"category": "网络",
|
||||
"priority": "high",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/tags",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["suggested_tags"] == ["VPN", "网络"]
|
||||
assert data["data"]["priority"] == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_unauthorized(self, wingman_client, test_conversation):
|
||||
"""认证验证:未登录坐席访问应返回错误码 1002。"""
|
||||
conv = test_conversation
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/tags",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1002
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_conversation_not_found(
|
||||
self, wingman_client, authed_agent
|
||||
):
|
||||
"""会话不存在:传入不存在的 conversation_id 应返回错误码 1003。"""
|
||||
_, token = authed_agent
|
||||
fake_id = str(uuid.uuid4())
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{fake_id}/wingman/tags",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
assert data["code"] == 1003
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_wingman_degradation(
|
||||
self, wingman_client, authed_agent, test_conversation
|
||||
):
|
||||
"""Wingman 降级:WingmanService 返回降级标签时,API 不应 500。"""
|
||||
agent, token = authed_agent
|
||||
conv = test_conversation
|
||||
mock_wingman = wingman_client._mock_wingman
|
||||
|
||||
mock_wingman.suggest_tags = AsyncMock(
|
||||
return_value={
|
||||
"suggested_tags": [],
|
||||
"category": "",
|
||||
"priority": "medium",
|
||||
}
|
||||
)
|
||||
|
||||
response = await wingman_client.post(
|
||||
f"/conversations/{conv.id}/wingman/tags",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["code"] == 0
|
||||
assert data["data"]["suggested_tags"] == []
|
||||
assert data["data"]["priority"] == "medium"
|
||||
@@ -0,0 +1,459 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — WingmanService 单元测试
|
||||
# =============================================================================
|
||||
# 测试覆盖:
|
||||
# 1. _build_context_messages() — 消息角色映射
|
||||
# 2. _parse_json_response() — JSON 解析三种场景
|
||||
# 3. _estimate_confidence() — 置信度估算
|
||||
# 4. generate_draft() — 草稿生成 + 降级
|
||||
# 5. generate_summary() — 摘要生成 + 降级
|
||||
# 6. suggest_tags() — 标签建议 + 降级
|
||||
# =============================================================================
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.wingman_service import WingmanService
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _build_context_messages 测试
|
||||
# =============================================================================
|
||||
class TestBuildContextMessages:
|
||||
"""测试消息角色映射逻辑。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
def test_employee_messages_mapped_to_user(self):
|
||||
"""员工消息应映射为 user 角色。"""
|
||||
messages = [
|
||||
{"sender_type": "employee", "content": "我的电脑开不了机"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 system prompt")
|
||||
|
||||
assert len(result) == 2 # system prompt + 1 条消息
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[1]["role"] == "user"
|
||||
assert result[1]["content"] == "我的电脑开不了机"
|
||||
|
||||
def test_agent_messages_mapped_to_assistant(self):
|
||||
"""坐席消息应映射为 assistant 角色。"""
|
||||
messages = [
|
||||
{"sender_type": "agent", "content": "请尝试重启电脑"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
assert result[1]["role"] == "assistant"
|
||||
|
||||
def test_ai_messages_mapped_to_assistant(self):
|
||||
"""AI 消息应映射为 assistant 角色。"""
|
||||
messages = [
|
||||
{"sender_type": "ai", "content": "建议您检查电源线连接"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
assert result[1]["role"] == "assistant"
|
||||
|
||||
def test_system_messages_are_skipped(self):
|
||||
"""系统消息应被跳过(已有 system prompt)。"""
|
||||
messages = [
|
||||
{"sender_type": "system", "content": "坐席已接入"},
|
||||
{"sender_type": "employee", "content": "你好"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
# system prompt + employee message, system消息被跳过
|
||||
assert len(result) == 2
|
||||
assert result[1]["role"] == "user"
|
||||
assert result[1]["content"] == "你好"
|
||||
|
||||
def test_empty_content_messages_are_skipped(self):
|
||||
"""内容为空的消息应被跳过。"""
|
||||
messages = [
|
||||
{"sender_type": "employee", "content": ""},
|
||||
{"sender_type": "agent", "content": "收到"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
# system prompt + agent message, 空 content 的 employee 消息被跳过
|
||||
assert len(result) == 2
|
||||
assert result[1]["role"] == "assistant"
|
||||
|
||||
def test_unknown_sender_type_defaults_to_user(self):
|
||||
"""未知发送者类型应默认映射为 user。"""
|
||||
messages = [
|
||||
{"sender_type": "unknown_type", "content": "未知消息"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
assert result[1]["role"] == "user"
|
||||
|
||||
def test_full_conversation_ordering(self):
|
||||
"""多轮对话应保持正确的顺序。"""
|
||||
messages = [
|
||||
{"sender_type": "employee", "content": "VPN连不上"},
|
||||
{"sender_type": "agent", "content": "请问是哪个VPN?"},
|
||||
{"sender_type": "employee", "content": "公司内网VPN"},
|
||||
{"sender_type": "ai", "content": "建议检查VPN客户端版本"},
|
||||
]
|
||||
result = self.service._build_context_messages(messages, "测试 prompt")
|
||||
|
||||
# system prompt + 4 条消息
|
||||
assert len(result) == 5
|
||||
assert result[1]["role"] == "user"
|
||||
assert result[2]["role"] == "assistant"
|
||||
assert result[3]["role"] == "user"
|
||||
assert result[4]["role"] == "assistant"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _parse_json_response 测试
|
||||
# =============================================================================
|
||||
class TestParseJsonResponse:
|
||||
"""测试 AI 返回 JSON 解析的三种场景。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
def test_parse_pure_json(self):
|
||||
"""纯 JSON 字符串应直接解析成功。"""
|
||||
content = '{"problem": "VPN断连", "cause": "证书过期", "solution": "更新证书"}'
|
||||
default = {"problem": "", "cause": "", "solution": ""}
|
||||
result = self.service._parse_json_response(content, default)
|
||||
|
||||
assert result["problem"] == "VPN断连"
|
||||
assert result["cause"] == "证书过期"
|
||||
assert result["solution"] == "更新证书"
|
||||
|
||||
def test_parse_markdown_code_block(self):
|
||||
"""markdown 代码块中的 JSON 应被提取并解析。"""
|
||||
content = '```json\n{"suggested_tags": ["VPN", "网络"], "category": "网络", "priority": "high"}\n```'
|
||||
default = {"suggested_tags": [], "category": "", "priority": "medium"}
|
||||
result = self.service._parse_json_response(content, default)
|
||||
|
||||
assert result["suggested_tags"] == ["VPN", "网络"]
|
||||
assert result["category"] == "网络"
|
||||
assert result["priority"] == "high"
|
||||
|
||||
def test_parse_markdown_code_block_without_language(self):
|
||||
"""无语言标记的 markdown 代码块也应被解析。"""
|
||||
content = '```\n{"problem": "密码错误", "cause": "输入错误", "solution": "重置密码"}\n```'
|
||||
default = {"problem": "", "cause": "", "solution": ""}
|
||||
result = self.service._parse_json_response(content, default)
|
||||
|
||||
assert result["problem"] == "密码错误"
|
||||
|
||||
def test_parse_curly_brace_extraction(self):
|
||||
"""含有额外文本的 AI 返回应通过首尾花括号提取 JSON。"""
|
||||
content = '这是AI的分析结果:{"problem": "邮箱满了", "cause": "未清理", "solution": "清理邮箱"} 希望对你有帮助'
|
||||
default = {"problem": "", "cause": "", "solution": ""}
|
||||
result = self.service._parse_json_response(content, default)
|
||||
|
||||
assert result["problem"] == "邮箱满了"
|
||||
assert result["solution"] == "清理邮箱"
|
||||
|
||||
def test_parse_empty_content_returns_default(self):
|
||||
"""空内容应返回默认值。"""
|
||||
default = {"problem": "无法自动生成摘要", "cause": "", "solution": ""}
|
||||
result = self.service._parse_json_response("", default)
|
||||
assert result == default
|
||||
|
||||
def test_parse_invalid_content_returns_default(self):
|
||||
"""无法解析的内容应返回默认值。"""
|
||||
content = "这不是JSON格式的内容"
|
||||
default = {"problem": "无法自动生成摘要", "cause": "", "solution": ""}
|
||||
result = self.service._parse_json_response(content, default)
|
||||
assert result == default
|
||||
|
||||
def test_parse_none_content_returns_default(self):
|
||||
"""None 内容应返回默认值。"""
|
||||
default = {"suggested_tags": [], "category": "", "priority": "medium"}
|
||||
result = self.service._parse_json_response(None, default)
|
||||
assert result == default
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _estimate_confidence 测试
|
||||
# =============================================================================
|
||||
class TestEstimateConfidence:
|
||||
"""测试 AI 草稿置信度估算。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
def test_short_content_low_confidence(self):
|
||||
"""过短内容应返回低置信度。"""
|
||||
result = self.service._estimate_confidence("hi")
|
||||
# len("hi") < 5,应返回 0.2
|
||||
assert result == 0.2
|
||||
|
||||
def test_very_short_content_low_confidence(self):
|
||||
"""极短内容(<10字符)应降低置信度。"""
|
||||
result = self.service._estimate_confidence("试试看")
|
||||
# len("试试看") = 3 < 5,返回 0.2
|
||||
assert result == 0.2
|
||||
|
||||
def test_moderate_content_confidence(self):
|
||||
"""适中内容应返回中等置信度。"""
|
||||
# 30+ 字符,无不确定措辞,无确定措辞
|
||||
content = "您好,请检查您的网络连接是否正常,然后重新启动应用程序即可。"
|
||||
result = self.service._estimate_confidence(content)
|
||||
# 基础 0.8,长度 >= 30 不减,无不确定措辞不减,无确定措辞不加
|
||||
assert 0.7 <= result <= 1.0
|
||||
|
||||
def test_uncertain_phrases_lower_confidence(self):
|
||||
"""包含不确定措辞应降低置信度。"""
|
||||
content_with_uncertain = "可能是网络问题,建议您检查一下连接"
|
||||
content_without_uncertain = "这是网络问题,请检查网络连接设置"
|
||||
|
||||
conf_with = self.service._estimate_confidence(content_with_uncertain)
|
||||
conf_without = self.service._estimate_confidence(content_without_uncertain)
|
||||
|
||||
# 含"可能"和"建议您"的置信度应更低
|
||||
assert conf_with < conf_without
|
||||
|
||||
def test_confident_phrases_raise_confidence(self):
|
||||
"""包含确定措辞(步骤、链接等)应提高置信度。"""
|
||||
content = "请按以下步骤操作:1.打开设置 2.点击网络 3.选择连接。详细请查看 http://help.example.com"
|
||||
result = self.service._estimate_confidence(content)
|
||||
# 包含"步骤"、"请按以下"、"http" 至少 +0.15
|
||||
assert result >= 0.9
|
||||
|
||||
def test_confidence_bounded_to_range(self):
|
||||
"""置信度应限制在 0.0-1.0 范围内。"""
|
||||
# 多个不确定措辞 + 短内容
|
||||
content = "可能大概也许不确定建议您"
|
||||
result = self.service._estimate_confidence(content)
|
||||
assert result >= 0.0
|
||||
|
||||
# 多个确定措辞
|
||||
content = "请按以下步骤操作:点击打开 http://link1 http://link2"
|
||||
result = self.service._estimate_confidence(content)
|
||||
assert result <= 1.0
|
||||
|
||||
def test_empty_string_returns_minimum(self):
|
||||
"""空字符串应返回低置信度。"""
|
||||
result = self.service._estimate_confidence("")
|
||||
assert result == 0.2
|
||||
|
||||
def test_whitespace_only_returns_minimum(self):
|
||||
"""仅空白字符应返回低置信度。"""
|
||||
result = self.service._estimate_confidence(" ")
|
||||
assert result == 0.2
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# generate_draft 降级测试
|
||||
# =============================================================================
|
||||
class TestGenerateDraft:
|
||||
"""测试草稿生成的降级处理。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_degradation_when_api_returns_none(self):
|
||||
"""Wingman API 返回 None 时应返回降级默认值。"""
|
||||
self.service._call_wingman_api = AsyncMock(return_value=None)
|
||||
|
||||
result = await self.service.generate_draft(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "VPN连不上"}],
|
||||
)
|
||||
|
||||
assert result["content"] == ""
|
||||
assert result["confidence"] == 0.0
|
||||
assert "不可用" in result["reasoning"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_degradation_when_api_raises_exception(self):
|
||||
"""Wingman API 抛异常时应返回降级默认值(不抛异常)。"""
|
||||
self.service._call_wingman_api = AsyncMock(side_effect=Exception("连接超时"))
|
||||
|
||||
result = await self.service.generate_draft(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "VPN连不上"}],
|
||||
)
|
||||
|
||||
assert result["content"] == ""
|
||||
assert result["confidence"] == 0.0
|
||||
assert "异常" in result["reasoning"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_success_returns_structured_result(self):
|
||||
"""Wingman API 正常返回时应返回结构化结果。"""
|
||||
self.service._call_wingman_api = AsyncMock(
|
||||
return_value="请尝试重启VPN客户端并重新输入密码"
|
||||
)
|
||||
|
||||
result = await self.service.generate_draft(
|
||||
conversation_id="conv-001",
|
||||
messages=[
|
||||
{"sender_type": "employee", "content": "VPN连不上"},
|
||||
{"sender_type": "agent", "content": "请问报什么错?"},
|
||||
],
|
||||
)
|
||||
|
||||
assert result["content"] == "请尝试重启VPN客户端并重新输入密码"
|
||||
assert 0.0 < result["confidence"] <= 1.0
|
||||
assert "推理" in result["reasoning"] or "对话上下文" in result["reasoning"]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# generate_summary 降级测试
|
||||
# =============================================================================
|
||||
class TestGenerateSummary:
|
||||
"""测试摘要生成的降级处理。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_degradation_when_api_returns_none(self):
|
||||
"""Wingman API 返回 None 时应返回降级默认摘要。"""
|
||||
self.service._call_wingman_api = AsyncMock(return_value=None)
|
||||
|
||||
result = await self.service.generate_summary(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "求助"}],
|
||||
)
|
||||
|
||||
assert result["problem"] == "无法自动生成摘要"
|
||||
assert result["cause"] == ""
|
||||
assert result["solution"] == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_degradation_when_api_raises_exception(self):
|
||||
"""Wingman API 抛异常时应返回降级默认摘要。"""
|
||||
self.service._call_wingman_api = AsyncMock(side_effect=Exception("超时"))
|
||||
|
||||
result = await self.service.generate_summary(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "求助"}],
|
||||
)
|
||||
|
||||
assert result["problem"] == "无法自动生成摘要"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_success_parses_json_response(self):
|
||||
"""Wingman API 正常返回 JSON 时应正确解析摘要。"""
|
||||
self.service._call_wingman_api = AsyncMock(
|
||||
return_value='{"problem": "VPN断连", "cause": "证书过期", "solution": "更新证书"}'
|
||||
)
|
||||
|
||||
result = await self.service.generate_summary(
|
||||
conversation_id="conv-001",
|
||||
messages=[
|
||||
{"sender_type": "employee", "content": "VPN连不上"},
|
||||
],
|
||||
)
|
||||
|
||||
assert result["problem"] == "VPN断连"
|
||||
assert result["cause"] == "证书过期"
|
||||
assert result["solution"] == "更新证书"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# suggest_tags 降级测试
|
||||
# =============================================================================
|
||||
class TestSuggestTags:
|
||||
"""测试标签建议的降级处理。"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试方法执行前的初始化。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://test-api"
|
||||
mock_settings.dify_wingman_api_key = "test-key"
|
||||
mock_settings.dify_wingman_timeout = 10
|
||||
self.service = WingmanService()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_degradation_when_api_returns_none(self):
|
||||
"""Wingman API 返回 None 时应返回降级默认标签。"""
|
||||
self.service._call_wingman_api = AsyncMock(return_value=None)
|
||||
|
||||
result = await self.service.suggest_tags(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "求助"}],
|
||||
)
|
||||
|
||||
assert result["suggested_tags"] == []
|
||||
assert result["category"] == ""
|
||||
assert result["priority"] == "medium"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_degradation_when_api_raises_exception(self):
|
||||
"""Wingman API 抛异常时应返回降级默认标签。"""
|
||||
self.service._call_wingman_api = AsyncMock(side_effect=Exception("超时"))
|
||||
|
||||
result = await self.service.suggest_tags(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "求助"}],
|
||||
)
|
||||
|
||||
assert result["suggested_tags"] == []
|
||||
assert result["priority"] == "medium"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_success_parses_json_response(self):
|
||||
"""Wingman API 正常返回 JSON 时应正确解析标签。"""
|
||||
self.service._call_wingman_api = AsyncMock(
|
||||
return_value='{"suggested_tags": ["VPN", "网络"], "category": "网络", "priority": "high"}'
|
||||
)
|
||||
|
||||
result = await self.service.suggest_tags(
|
||||
conversation_id="conv-001",
|
||||
messages=[{"sender_type": "employee", "content": "VPN连不上"}],
|
||||
)
|
||||
|
||||
assert result["suggested_tags"] == ["VPN", "网络"]
|
||||
assert result["category"] == "网络"
|
||||
assert result["priority"] == "high"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WingmanService 初始化测试
|
||||
# =============================================================================
|
||||
class TestWingmanServiceInit:
|
||||
"""测试 WingmanService 初始化。"""
|
||||
|
||||
def test_service_reads_config_from_settings(self):
|
||||
"""WingmanService 应从 settings 正确读取配置。"""
|
||||
with patch("app.services.wingman_service.settings") as mock_settings:
|
||||
mock_settings.dify_wingman_api_url = "http://custom-api-url"
|
||||
mock_settings.dify_wingman_api_key = "custom-api-key"
|
||||
mock_settings.dify_wingman_timeout = 60
|
||||
|
||||
service = WingmanService()
|
||||
|
||||
assert service.api_url == "http://custom-api-url"
|
||||
assert service.api_key == "custom-api-key"
|
||||
assert service.timeout == 60
|
||||
Reference in New Issue
Block a user