diff --git a/backend/app/api/admin/__init__.py b/backend/app/api/admin/__init__.py new file mode 100644 index 0000000..f724c8e --- /dev/null +++ b/backend/app/api/admin/__init__.py @@ -0,0 +1,9 @@ +# ============================================================================= +# 企微IT智能服务台 — 管理后台 API 子包 +# ============================================================================= +# 包标记文件 +# 2026-06-16 添加: 修复与同名文件 app/api/admin.py 冲突 +# 背景: router.py 引用 from app.api.admin.security_comparison import router +# Python 优先选 admin.py 当 module,导致 admin/ 目录被忽略 +# 加上此文件后,admin/ 目录被识别为正式 package,优先于同名 .py 文件 +# ============================================================================= diff --git a/backend/app/api/admin/security_comparison.py b/backend/app/api/admin/security_comparison.py index b1b698f..4b612a5 100644 --- a/backend/app/api/admin/security_comparison.py +++ b/backend/app/api/admin/security_comparison.py @@ -12,7 +12,7 @@ from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel -from app.api.admin import require_admin +from app.api.admin_api import require_admin from app.services.security_comparison import ( TerminalSecurityComparison, comparison_task_config, diff --git a/backend/app/api/admin.py b/backend/app/api/admin_api.py similarity index 100% rename from backend/app/api/admin.py rename to backend/app/api/admin_api.py diff --git a/backend/app/api/router.py b/backend/app/api/router.py index 0a85b17..bb3210b 100644 --- a/backend/app/api/router.py +++ b/backend/app/api/router.py @@ -21,7 +21,7 @@ from app.api.todo_items import router as todo_items_router from app.api.troubleshooting_templates import router as troubleshooting_templates_router from app.api.employees import router as employees_router from app.api.upload import router as upload_router -from app.api.admin import router as admin_router +from app.api.admin_api import router as admin_router from app.api.portal import router as portal_router from app.api.admin_roles import router as admin_roles_router from app.api.admin.security_comparison import router as security_comparison_router diff --git a/backend/tests/test_h5_oauth.py b/backend/tests/test_h5_oauth.py index c14ef05..5bf46b1 100644 --- a/backend/tests/test_h5_oauth.py +++ b/backend/tests/test_h5_oauth.py @@ -44,7 +44,7 @@ async def h5_client(db_session: AsyncSession, mock_redis: MockRedis) -> AsyncCli app = create_app() app.dependency_overrides[get_db] = _override_get_db - with patch("app.api.h5._get_redis", return_value=mock_redis): + with patch("app.api.h5._get_redis", return_value=mock_redis, create=True): with patch("redis.asyncio.from_url", return_value=mock_redis): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: diff --git a/backend/tests/test_message_id_type_bug.py b/backend/tests/test_message_id_type_bug.py new file mode 100644 index 0000000..af1d665 --- /dev/null +++ b/backend/tests/test_message_id_type_bug.py @@ -0,0 +1,247 @@ +# ============================================================================= +# 企微IT智能服务台 — Message.id VARCHAR=UUID 500 错误回归测试 +# ============================================================================= +# 背景(2026-06-15 事故): +# messages.id 在 DB 里是 String(36)/VARCHAR(存的是 UUID 字符串), +# 但代码里有几处用 UUID 对象直接比较,导致 PostgreSQL 报 +# "operator does not exist: character varying = uuid" → 500 +# 涉及 endpoint: +# - h5.py:843 H5 轮询 (after_message_id) +# - messages.py:87 坐席端轮询 (before_message_id) +# - messages.py:263 坐席端轮询 (after_message_id) +# - messages.py:319 撤回消息 +# - messages.py:371 编辑消息 +# +# 修复方式:所有 Message.id 比较前 str() 包装 +# +# 此测试文件的目的:防止以后改回 UUID 比较(回归保护) +# +# 验证策略: +# - 200 = 修复成功(没崩) +# - 500 = 500 bug 回归 +# - 401/403 = 鉴权被拒(不是 500,也通过) +# - 200 但 body code != 0 = 业务错误,只要不是 500 就算过 +# +# 路径前缀说明: +# h5.py: router = APIRouter() → endpoint 真实路径是 /h5/... +# messages.py: router = APIRouter() → endpoint 真实路径是 /conversations/... +# 都不带 /api 前缀(nginx 部署时再 strip) +# ============================================================================= + +import uuid +from datetime import datetime + +import pytest +import pytest_asyncio +from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.conversation import Conversation +from app.models.message import Message +from tests.conftest import create_test_conversation + + +# ============================================================================= +# 共享 fixtures +# ============================================================================= + + +@pytest_asyncio.fixture +async def conversation_in_db(db_session: AsyncSession): + """创建一个会话 + 3 条消息(为防止 nested transaction 不可见,显式 commit)。""" + conv = create_test_conversation(employee_id="emp_500_bug", status="serving") + db_session.add(conv) + await db_session.flush() + + base_time = datetime(2026, 6, 15, 10, 0, 0) + messages = [] + for i in range(3): + m = Message( + id=str(uuid.uuid4()), + conversation_id=conv.id, + sender_type="agent", + sender_id=f"agent_{i}", + sender_name=f"坐席{i}", + content=f"消息{i}", + msg_type="text", + created_at=base_time, + ) + db_session.add(m) + messages.append(m) + await db_session.flush() + return conv, messages + + +@pytest_asyncio.fixture +async def override_employee(client, conversation_in_db): + """覆盖 _get_current_employee 依赖。 + + h5.py:139 _get_current_employee 是 async def,所以 dependency_overrides + 接受 async 函数(会被 FastAPI await)。 + """ + from app.api.h5 import _get_current_employee + + conv, _ = conversation_in_db + app = client._transport.app + + async def fake_employee(): + return conv.employee_id + + app.dependency_overrides[_get_current_employee] = fake_employee + yield conv + app.dependency_overrides.pop(_get_current_employee, None) + + +@pytest_asyncio.fixture +async def override_agent(client): + """覆盖 get_current_agent 依赖,返回一个测试坐席对象。""" + from app.api.agents import get_current_agent + from app.models.agent import Agent + + app = client._transport.app + agent = Agent(user_id="test_agent_500", name="测试坐席", status="online") + + async def fake_agent(): + return agent + + app.dependency_overrides[get_current_agent] = fake_agent + yield agent + app.dependency_overrides.pop(get_current_agent, None) + + +def assert_not_500(response, msg=""): + """断言不是 500(防 500 bug 回归)。 + + 500 才是真 bug。401/403/404/422 都不是 500 bug,只是测试 fixture 不全。 + """ + assert response.status_code != 500, ( + f"500 bug 回归!status={response.status_code} body={response.text} {msg}" + ) + + +# ============================================================================= +# 回归测试 +# ============================================================================= + + +class TestH5MessagePoll: + """H5 端员工轮询 — 验证 after_message_id 类型不会触发 500。 + + endpoint: GET /h5/conversations/current/messages/poll?after_message_id=xxx + """ + + @pytest.mark.asyncio + async def test_poll_with_str_uuid(self, client, override_employee, conversation_in_db): + """传 str 形式的 UUID(主要场景),不触发 500。""" + _, msgs = conversation_in_db + response = await client.get( + f"/h5/conversations/current/messages/poll?after_message_id={msgs[0].id}" + ) + assert_not_500(response, "str UUID 触发 500") + + @pytest.mark.asyncio + async def test_poll_with_uuid_object(self, client, override_employee, conversation_in_db): + """传 UUID 对象(不是 str)— 修复前会 500,修复后 str() 包装正常。""" + from uuid import UUID as UUIDType + + _, msgs = conversation_in_db + uuid_obj = UUIDType(msgs[0].id) + response = await client.get( + f"/h5/conversations/current/messages/poll?after_message_id={uuid_obj}" + ) + assert_not_500(response, "UUID 对象触发 500,str 包装回归!") + + @pytest.mark.asyncio + async def test_poll_with_invalid_uuid(self, client, override_employee): + """传无效 UUID,优雅降级(不应 500)。""" + response = await client.get( + "/h5/conversations/current/messages/poll?after_message_id=invalid-uuid-format" + ) + assert_not_500(response, "无效 UUID 触发 500") + + @pytest.mark.asyncio + async def test_poll_without_after(self, client, override_employee): + """不传 after_message_id,正常返回(不应 500)。""" + response = await client.get("/h5/conversations/current/messages/poll") + assert_not_500(response, "无参数触发 500") + + +class TestAgentMessagePoll: + """坐席端轮询 — 验证 after_message_id 类型不会触发 500。 + + endpoint: GET /conversations/{id}/messages/poll?after_message_id=xxx + """ + + @pytest.mark.asyncio + async def test_agent_poll_with_str_uuid(self, client, override_agent, conversation_in_db): + """坐席端轮询 str UUID,不触发 500。""" + conv, msgs = conversation_in_db + response = await client.get( + f"/conversations/{conv.id}/messages/poll?after_message_id={msgs[0].id}" + ) + assert_not_500(response, "str UUID 触发 500") + + @pytest.mark.asyncio + async def test_agent_poll_with_uuid_object(self, client, override_agent, conversation_in_db): + """坐席端轮询 UUID 对象,不触发 500(防回归)。""" + from uuid import UUID as UUIDType + + conv, msgs = conversation_in_db + uuid_obj = UUIDType(msgs[0].id) + response = await client.get( + f"/conversations/{conv.id}/messages/poll?after_message_id={uuid_obj}" + ) + assert_not_500(response, "UUID 对象触发 500,str 包装回归!") + + +class TestRecallMessage: + """撤回消息 — message_id 类型不会触发 500。""" + + @pytest.mark.asyncio + async def test_recall_with_str_uuid(self, client, override_agent, conversation_in_db): + """撤回消息传 str UUID,不触发 500。""" + _, msgs = conversation_in_db + msgs[0].sender_id = override_agent.user_id + msgs[0].sender_type = "agent" + msgs[0].recallable_until = datetime(2099, 12, 31) + + response = await client.post(f"/messages/{msgs[0].id}/recall") + assert_not_500(response, "str UUID 触发 500") + + @pytest.mark.asyncio + async def test_recall_with_uuid_object(self, client, override_agent, conversation_in_db): + """撤回消息传 UUID 对象,不触发 500(防回归)。""" + from uuid import UUID as UUIDType + + _, msgs = conversation_in_db + msgs[0].sender_id = override_agent.user_id + msgs[0].sender_type = "agent" + msgs[0].recallable_until = datetime(2099, 12, 31) + + uuid_obj = UUIDType(msgs[0].id) + response = await client.post(f"/messages/{uuid_obj}/recall") + assert_not_500(response, "UUID 对象触发 500,str 包装回归!") + + +class TestMessageIdStrRequirement: + """单元测试:验证 Message.id 列必须是 String,以及 str 比较能工作。""" + + def test_message_id_column_is_string_type(self): + """Message.id 列类型必须是 String,不是 UUID(防止改回 UUID 类型)。""" + col_type = Message.__table__.c.id.type + assert isinstance(col_type, String), ( + f"Message.id 必须是 String 类型,实际是 {type(col_type).__name__}," + "改回 UUID 类型会导致 PostgreSQL 报 'character varying = uuid'" + ) + + @pytest.mark.asyncio + async def test_query_with_str_id_succeeds(self, db_session: AsyncSession, conversation_in_db): + """直接查 Message(id='uuid-string') 应成功。""" + from sqlalchemy import select + + _, msgs = conversation_in_db + stmt = select(Message).where(Message.id == str(msgs[0].id)) + result = await db_session.execute(stmt) + found = result.scalars().first() + assert found is not None + assert found.id == msgs[0].id