# File: wecom_it_smart_desk/backend/tests/test_wingman.py

# =============================================================================
# 企微IT智能服务台 — Wingman API 端点测试
# =============================================================================
# 测试覆盖:
# 1. POST /api/conversations/{id}/wingman/draft — 正常/认证/404/降级
# 2. POST /api/conversations/{id}/wingman/summary — 正常/认证/404/降级
# 3. POST /api/conversations/{id}/wingman/tags — 正常/认证/404/降级
# =============================================================================
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.services.wingman_service import WingmanService
from app.dependencies import dep_wingman_service
from app.database import get_db
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
# =============================================================================
# Fixtures
# =============================================================================
@pytest_asyncio.fixture
async def wingman_client(db_session: AsyncSession, mock_redis: MockRedis):
    """Yield an httpx AsyncClient bound to a fresh app with a mocked WingmanService.

    Setup performed:
      * ``get_db`` is overridden to yield the test ``db_session``.
      * ``dep_wingman_service`` is overridden to return a ``MagicMock`` spec'd on
        ``WingmanService``. The mock is exposed as ``client._mock_wingman`` so
        individual tests can configure its return values.
      * ``app.api.agents._get_redis`` and ``redis.asyncio.from_url`` are patched
        to return ``mock_redis`` (the agent auth dependency needs Redis).

    The dependency overrides are cleared in a ``finally`` block so they are
    removed even if client or patch teardown raises.
    """
    # spec= restricts attribute reads to names that exist on WingmanService.
    mock_wingman = MagicMock(spec=WingmanService)
    mock_wingman.close = AsyncMock()

    async def _override_get_db():
        yield db_session

    async def _override_dep_wingman():
        return mock_wingman

    # Imported inside the fixture rather than at module top; presumably to defer
    # app construction until the fixture runs (TODO confirm).
    from app.main import create_app

    app = create_app()
    app.dependency_overrides[get_db] = _override_get_db
    app.dependency_overrides[dep_wingman_service] = _override_dep_wingman

    try:
        # NOTE: h5.py was refactored and no longer has _get_redis, so it is not patched.
        with patch("app.api.agents._get_redis", return_value=mock_redis), patch(
            "redis.asyncio.from_url", return_value=mock_redis
        ):
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as ac:
                # Attach the mock so tests can configure per-test behaviour.
                ac._mock_wingman = mock_wingman
                yield ac
    finally:
        app.dependency_overrides.clear()
@pytest_asyncio.fixture
async def authed_agent(db_session: AsyncSession, mock_redis: MockRedis):
    """Create a logged-in agent and return an ``(agent, token)`` tuple.

    The agent row is added to ``db_session`` and flushed. The bearer token is
    registered in the mock Redis under ``agent:token:<token>``. It has a
    28800-second (8-hour) TTL and maps to the agent's user_id.
    """
    token = "test-wingman-token-001"

    new_agent = create_test_agent(user_id="wingman_test_agent", name="测试坐席")
    db_session.add(new_agent)
    await db_session.flush()

    # Seed the token -> user_id mapping; presumably this is what the agent auth
    # dependency looks up when it sees "Authorization: Bearer <token>".
    redis_key = f"agent:token:{token}"
    await mock_redis.setex(redis_key, 28800, "wingman_test_agent")

    return new_agent, token
@pytest_asyncio.fixture
async def test_conversation(db_session: AsyncSession, authed_agent):
    """Create a ``serving`` conversation assigned to the authed agent and return it.

    Two text messages are attached, an employee question followed by the
    agent's reply. This gives the Wingman endpoints some history to work on.
    """
    owner = authed_agent[0]

    conversation = create_test_conversation(status="serving")
    conversation.assigned_agent_id = owner.user_id
    db_session.add(conversation)
    # Flush before the messages below reference conversation.id.
    await db_session.flush()

    history = [
        Message(
            conversation_id=conversation.id,
            sender_type="employee",
            sender_id="emp001",
            sender_name="员工张三",
            content="VPN连不上怎么办",
            msg_type="text",
        ),
        Message(
            conversation_id=conversation.id,
            sender_type="agent",
            sender_id=owner.user_id,
            sender_name=owner.name,
            content="请问报什么错误",
            msg_type="text",
        ),
    ]
    db_session.add_all(history)
    await db_session.flush()

    return conversation
# =============================================================================
# Draft API 测试
# =============================================================================
class TestDraftAPI:
    """Tests for the POST /api/conversations/{id}/wingman/draft endpoint."""

    @pytest.mark.asyncio
    async def test_draft_success(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Happy path: an authenticated agent gets a draft for an existing conversation."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        mock_wingman.generate_draft = AsyncMock(
            return_value={
                "content": "请尝试重启VPN客户端",
                "confidence": 0.85,
                "reasoning": "基于最近 2 条对话上下文生成",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/draft",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["content"] == "请尝试重启VPN客户端"
        assert data["data"]["confidence"] == 0.85
        # The payload must come from the service, not be produced by the endpoint itself.
        mock_wingman.generate_draft.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_draft_unauthorized(self, wingman_client, test_conversation):
        """Auth check: a request without a token must yield error code 1002."""
        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/draft",
        )

        data = response.json()
        assert data["code"] == 1002  # ERR_UNAUTHORIZED
        # An unauthenticated request must never reach the Wingman service.
        wingman_client._mock_wingman.generate_draft.assert_not_called()

    @pytest.mark.asyncio
    async def test_draft_conversation_not_found(
        self, wingman_client, authed_agent
    ):
        """Unknown conversation: a random conversation_id must yield error code 1003."""
        _, token = authed_agent
        fake_id = str(uuid.uuid4())

        response = await wingman_client.post(
            f"/conversations/{fake_id}/wingman/draft",
            headers={"Authorization": f"Bearer {token}"},
        )

        data = response.json()
        assert data["code"] == 1003  # ERR_NOT_FOUND
        # The endpoint must short-circuit before calling the service.
        wingman_client._mock_wingman.generate_draft.assert_not_called()

    @pytest.mark.asyncio
    async def test_draft_wingman_degradation(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Degradation: when WingmanService returns its fallback result, the API must not 500."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        # Simulate the service's degraded (fallback) response.
        mock_wingman.generate_draft = AsyncMock(
            return_value={
                "content": "",
                "confidence": 0.0,
                "reasoning": "Wingman 服务暂不可用",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/draft",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["content"] == ""
        assert data["data"]["confidence"] == 0.0
        mock_wingman.generate_draft.assert_awaited_once()
# =============================================================================
# Summary API 测试
# =============================================================================
class TestSummaryAPI:
    """Tests for the POST /api/conversations/{id}/wingman/summary endpoint."""

    @pytest.mark.asyncio
    async def test_summary_success(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Happy path: an authenticated agent gets a summary for an existing conversation."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        mock_wingman.generate_summary = AsyncMock(
            return_value={
                "problem": "VPN连接失败",
                "cause": "证书过期",
                "solution": "更新VPN证书并重启客户端",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/summary",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["problem"] == "VPN连接失败"
        assert data["data"]["cause"] == "证书过期"
        assert data["data"]["solution"] == "更新VPN证书并重启客户端"
        # The payload must come from the service, not be produced by the endpoint itself.
        mock_wingman.generate_summary.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_summary_unauthorized(self, wingman_client, test_conversation):
        """Auth check: a request without a token must yield error code 1002."""
        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/summary",
        )

        data = response.json()
        assert data["code"] == 1002  # ERR_UNAUTHORIZED
        # An unauthenticated request must never reach the Wingman service.
        wingman_client._mock_wingman.generate_summary.assert_not_called()

    @pytest.mark.asyncio
    async def test_summary_conversation_not_found(
        self, wingman_client, authed_agent
    ):
        """Unknown conversation: a random conversation_id must yield error code 1003."""
        _, token = authed_agent
        fake_id = str(uuid.uuid4())

        response = await wingman_client.post(
            f"/conversations/{fake_id}/wingman/summary",
            headers={"Authorization": f"Bearer {token}"},
        )

        data = response.json()
        assert data["code"] == 1003  # ERR_NOT_FOUND
        # The endpoint must short-circuit before calling the service.
        wingman_client._mock_wingman.generate_summary.assert_not_called()

    @pytest.mark.asyncio
    async def test_summary_wingman_degradation(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Degradation: when WingmanService returns its fallback summary, the API must not 500."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        # Simulate the service's degraded (fallback) response.
        mock_wingman.generate_summary = AsyncMock(
            return_value={
                "problem": "无法自动生成摘要",
                "cause": "",
                "solution": "",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/summary",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["problem"] == "无法自动生成摘要"
        assert data["data"]["cause"] == ""
        assert data["data"]["solution"] == ""
        mock_wingman.generate_summary.assert_awaited_once()
# =============================================================================
# Tags API 测试
# =============================================================================
class TestTagsAPI:
    """Tests for the POST /api/conversations/{id}/wingman/tags endpoint."""

    @pytest.mark.asyncio
    async def test_tags_success(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Happy path: an authenticated agent gets tag suggestions for an existing conversation."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        mock_wingman.suggest_tags = AsyncMock(
            return_value={
                "suggested_tags": ["VPN", "网络"],
                "category": "网络",
                "priority": "high",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/tags",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["suggested_tags"] == ["VPN", "网络"]
        assert data["data"]["category"] == "网络"
        assert data["data"]["priority"] == "high"
        # The payload must come from the service, not be produced by the endpoint itself.
        mock_wingman.suggest_tags.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_tags_unauthorized(self, wingman_client, test_conversation):
        """Auth check: a request without a token must yield error code 1002."""
        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/tags",
        )

        data = response.json()
        assert data["code"] == 1002  # ERR_UNAUTHORIZED
        # An unauthenticated request must never reach the Wingman service.
        wingman_client._mock_wingman.suggest_tags.assert_not_called()

    @pytest.mark.asyncio
    async def test_tags_conversation_not_found(
        self, wingman_client, authed_agent
    ):
        """Unknown conversation: a random conversation_id must yield error code 1003."""
        _, token = authed_agent
        fake_id = str(uuid.uuid4())

        response = await wingman_client.post(
            f"/conversations/{fake_id}/wingman/tags",
            headers={"Authorization": f"Bearer {token}"},
        )

        data = response.json()
        assert data["code"] == 1003  # ERR_NOT_FOUND
        # The endpoint must short-circuit before calling the service.
        wingman_client._mock_wingman.suggest_tags.assert_not_called()

    @pytest.mark.asyncio
    async def test_tags_wingman_degradation(
        self, wingman_client, authed_agent, test_conversation
    ):
        """Degradation: when WingmanService returns its fallback tags, the API must not 500."""
        _, token = authed_agent
        mock_wingman = wingman_client._mock_wingman
        # Simulate the service's degraded (fallback) response.
        mock_wingman.suggest_tags = AsyncMock(
            return_value={
                "suggested_tags": [],
                "category": "",
                "priority": "medium",
            }
        )

        response = await wingman_client.post(
            f"/conversations/{test_conversation.id}/wingman/tags",
            headers={"Authorization": f"Bearer {token}"},
        )

        assert response.status_code == 200
        data = response.json()
        assert data["code"] == 0
        assert data["data"]["suggested_tags"] == []
        assert data["data"]["category"] == ""
        assert data["data"]["priority"] == "medium"
        mock_wingman.suggest_tags.assert_awaited_once()