Files

309 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — 消息体验功能测试
# =============================================================================
# 说明:测试消息体验相关功能,包括:
# 1. 撤回消息 (POST /api/messages/{id}/recall)
# 2. 删除消息 (DELETE /api/messages/{id})
# 3. 标记已读 (POST /api/conversations/{id}/mark-read)
# 4. 图片上传 (POST /api/messages/image)
# 5. 文件上传 (POST /api/messages/file)
# =============================================================================
import pytest
import pytest_asyncio
from datetime import datetime, timedelta
from uuid import uuid4
from tests.conftest import create_test_conversation, create_test_agent, MockRedis
# =============================================================================
# 测试用例:撤回消息
# =============================================================================
@pytest.mark.asyncio
async def test_recall_message_within_2min(client, db_session, mock_redis):
    """Recall succeeds while the message is still inside its 2-minute window.

    Expected: the message is recalled (its status becomes "recalled"). This
    test checks the HTTP-level contract: status 200, business code 0, and a
    response message containing "撤回成功".
    """
    from app.models.message import Message

    # Conversation that is currently being served by an agent.
    conversation = create_test_conversation(status="serving")
    db_session.add(conversation)
    await db_session.flush()

    # Recall deadline two minutes ahead, so the message is still recallable.
    # NOTE(review): naive local time; assumes the server also compares against
    # a naive datetime.now(). Confirm if the API moves to timezone-aware UTC.
    deadline = datetime.now() + timedelta(minutes=2)
    agent_message = Message(
        conversation_id=conversation.id,
        sender_type="agent",
        sender_id="test_agent_001",
        sender_name="测试坐席",
        content="测试消息内容",
        msg_type="text",
        recallable_until=deadline,
    )
    db_session.add(agent_message)
    await db_session.flush()

    response = await client.post(f"/api/messages/{agent_message.id}/recall")

    assert response.status_code == 200
    payload = response.json()
    assert payload.get("code") == 0
    assert "撤回成功" in payload.get("message", "")
@pytest.mark.asyncio
async def test_recall_message_after_2min_fails(client, db_session, mock_redis):
    """Recall is refused once the 2-minute window has elapsed.

    Expected: a 403, either as the HTTP status or as business code 403
    inside an HTTP 200 envelope.
    """
    conversation = create_test_conversation(status="serving")
    db_session.add(conversation)
    await db_session.flush()

    from app.models.message import Message

    # recallable_until lies one minute in the past: the window has expired.
    expired_at = datetime.now() - timedelta(minutes=1)
    stale_message = Message(
        conversation_id=conversation.id,
        sender_type="agent",
        sender_id="test_agent_001",
        sender_name="测试坐席",
        content="测试消息内容",
        msg_type="text",
        recallable_until=expired_at,
    )
    db_session.add(stale_message)
    await db_session.flush()

    response = await client.post(f"/api/messages/{stale_message.id}/recall")

    status = response.status_code
    # The JSON body is only parsed for HTTP 200, the one status where the
    # 403 may be reported as a business code instead.
    business_code = response.json().get("code") if status == 200 else None
    assert status == 403 or business_code == 403
@pytest.mark.asyncio
async def test_recall_nonexistent_message(client, db_session, mock_redis):
    """Recalling a message id that was never created yields HTTP 404."""
    # A freshly generated UUID that no stored message can have.
    missing_message_id = uuid4()
    url = f"/api/messages/{missing_message_id}/recall"

    response = await client.post(url)

    assert response.status_code == 404
@pytest.mark.asyncio
async def test_recall_non_agent_message_fails(client, db_session, mock_redis):
    """Recalling a message that was not sent by an agent is refused.

    Expected: a 403 (only agent-sent messages may be recalled), either as
    the HTTP status or as business code 403 inside an HTTP 200 envelope.

    The message gets a recall window that is still open, so the sender type
    is the only reason the request can be rejected.
    """
    conv = create_test_conversation(status="serving")
    db_session.add(conv)
    await db_session.flush()
    from app.models.message import Message
    # Employee-sent message. recallable_until is set inside the window so a
    # rejection cannot be caused by a missing or expired deadline instead of
    # the sender-type rule under test.
    message = Message(
        conversation_id=conv.id,
        sender_type="employee",
        sender_id="emp_001",
        sender_name="测试员工",
        content="员工消息",
        msg_type="text",
        recallable_until=datetime.now() + timedelta(minutes=2),
    )
    db_session.add(message)
    await db_session.flush()
    response = await client.post(f"/api/messages/{message.id}/recall")
    # A 403 is expected, reported either way.
    assert response.status_code == 403 or (
        response.status_code == 200 and response.json().get("code") == 403
    )
# =============================================================================
# 测试用例:删除消息
# =============================================================================
@pytest.mark.asyncio
async def test_delete_message_success(client, db_session, mock_redis):
    """Deleting an existing agent message succeeds.

    Expected: the message is deleted. Only the HTTP status is asserted here,
    and it may be 200 or 204.
    """
    from app.models.message import Message

    conversation = create_test_conversation(status="serving")
    db_session.add(conversation)
    await db_session.flush()

    doomed_message = Message(
        conversation_id=conversation.id,
        sender_type="agent",
        sender_id="test_agent_001",
        sender_name="测试坐席",
        content="测试消息内容",
        msg_type="text",
    )
    db_session.add(doomed_message)
    await db_session.flush()

    response = await client.delete(f"/api/messages/{doomed_message.id}")

    assert response.status_code in (200, 204)
@pytest.mark.asyncio
async def test_delete_nonexistent_message(client, db_session, mock_redis):
    """Deleting a message id that was never created yields HTTP 404."""
    # Random UUID; no stored message can carry it.
    missing_message_id = uuid4()

    response = await client.delete(f"/api/messages/{missing_message_id}")

    assert response.status_code == 404
# =============================================================================
# 测试用例:标记已读
# =============================================================================
@pytest.mark.asyncio
async def test_mark_read_updates_messages(client, db_session, mock_redis):
    """Marking a conversation as read succeeds.

    Expected: every unread message in the conversation becomes read. This
    test asserts the HTTP contract (status 200, business code 0); it does
    not re-read the per-message is_read flags.
    """
    conversation = create_test_conversation(status="serving")
    db_session.add(conversation)
    await db_session.flush()
    from app.models.message import Message

    # Two unread employee messages, created in the same order as before.
    unread_messages = [
        Message(
            conversation_id=conversation.id,
            sender_type="employee",
            sender_id="emp_001",
            sender_name="员工",
            content=text,
            msg_type="text",
            is_read=False,
        )
        for text in ("员工消息1", "员工消息2")
    ]
    db_session.add_all(unread_messages)
    await db_session.flush()

    response = await client.post(f"/api/conversations/{conversation.id}/mark-read")

    assert response.status_code == 200
    assert response.json().get("code") == 0
@pytest.mark.asyncio
async def test_mark_read_nonexistent_conversation(client, db_session, mock_redis):
    """Marking an unknown conversation as read yields HTTP 404."""
    unknown_conversation_id = uuid4()

    response = await client.post(
        f"/api/conversations/{unknown_conversation_id}/mark-read"
    )

    assert response.status_code == 404
# =============================================================================
# 测试用例:图片上传
# =============================================================================
@pytest.mark.asyncio
async def test_upload_image_within_limit(client, db_session, mock_redis):
    """Uploading an image under the 10 MB limit succeeds.

    Expected: HTTP 200, business code 0, and a "url" entry in the response
    data (the URL of the uploaded file).
    """
    # Small PNG-like payload: the 8-byte PNG signature followed by filler,
    # 8 + 15 * 5000 = 75,008 bytes (~73 KiB).
    image_data = b"\x89PNG\r\n\x1a\n" + b"fake_image_data" * 5000
    # Guard the precondition this test is named after.
    assert len(image_data) < 10 * 1024 * 1024
    files = {"file": ("test.png", image_data, "image/png")}
    response = await client.post("/api/messages/image", files=files)
    assert response.status_code == 200
    data = response.json()
    assert data.get("code") == 0
    # `or {}` also covers an explicit "data": null, which then fails as an
    # assertion instead of raising TypeError from `"url" in None`.
    assert "url" in (data.get("data") or {})
@pytest.mark.asyncio
async def test_upload_image_exceeds_limit(client, db_session, mock_redis):
    """Uploading an image larger than 10 MB is rejected.

    Expected: a 400, either as the HTTP status or as business code 400
    inside an HTTP 200 envelope.
    """
    # PNG signature followed by 11 MiB of filler, matching the accepted
    # small-image payload in everything but size. If the server checks the
    # PNG signature, a 400 could otherwise come from format validation
    # rather than from the size limit under test.
    large_data = b"\x89PNG\r\n\x1a\n" + b"x" * (11 * 1024 * 1024)
    assert len(large_data) > 10 * 1024 * 1024
    files = {"file": ("large.png", large_data, "image/png")}
    response = await client.post("/api/messages/image", files=files)
    assert response.status_code == 400 or (
        response.status_code == 200 and response.json().get("code") == 400
    )
@pytest.mark.asyncio
async def test_upload_invalid_image_type(client, db_session, mock_redis):
    """Uploading an image in an unsupported format (BMP) is rejected.

    Expected: a 400, either as the HTTP status or as business code 400
    inside an HTTP 200 envelope.
    """
    bmp_upload = ("test.bmp", b"fake_image", "image/bmp")

    response = await client.post("/api/messages/image", files={"file": bmp_upload})

    status = response.status_code
    # The JSON body is only inspected when the status is 200.
    business_code = response.json().get("code") if status == 200 else None
    assert status == 400 or business_code == 400
# =============================================================================
# 测试用例:文件上传
# =============================================================================
@pytest.mark.asyncio
async def test_upload_file_within_limit(client, db_session, mock_redis):
    """Uploading a file under the 10 MB limit succeeds.

    Expected: HTTP 200, business code 0, and a "url" entry in the response
    data (the URL of the uploaded file).
    """
    # Small payload: 17 * 5000 = 85,000 bytes (~83 KiB).
    file_data = b"fake_file_content" * 5000
    # Guard the precondition this test is named after.
    assert len(file_data) < 10 * 1024 * 1024
    files = {"file": ("test.pdf", file_data, "application/pdf")}
    response = await client.post("/api/messages/file", files=files)
    assert response.status_code == 200
    data = response.json()
    assert data.get("code") == 0
    # `or {}` also covers an explicit "data": null, which then fails as an
    # assertion instead of raising TypeError from `"url" in None`.
    assert "url" in (data.get("data") or {})
@pytest.mark.asyncio
async def test_upload_file_exceeds_limit(client, db_session, mock_redis):
    """Uploading a file larger than 10 MB is rejected.

    Expected: a 400, either as the HTTP status or as business code 400
    inside an HTTP 200 envelope.
    """
    eleven_mib = 11 * 1024 * 1024
    oversized_pdf = ("large.pdf", b"x" * eleven_mib, "application/pdf")

    response = await client.post("/api/messages/file", files={"file": oversized_pdf})

    status = response.status_code
    # The JSON body is only inspected when the status is 200.
    business_code = response.json().get("code") if status == 200 else None
    assert status == 400 or business_code == 400