Files
wecom_it_smart_desk/backend/tests/test_ws_endpoints.py
T

189 lines
8.1 KiB
Python
Raw Normal View History

# =============================================================================
# 企微IT智能服务台 — WebSocket 端点签名 + 错误码回归测试
# =============================================================================
# 背景(2026-06-21 事故):
# h5_websocket_endpoint 早期版本(2026-06-15 前)曾带一个多余 `request: Request`
# 参数,导致 FastAPI 启动时抛 "missing argument 'request'" / 客户端 WS 握手
# 直接失败、500 错误。前端 WS 连接直接失败,后端日志报错。
#
# 修复(2026-06-15):
# 移除 `request: Request` 参数(部分 Starlette 版本注入 Request 失败,
# 改用 `websocket.headers` 和 `websocket.query_params` 读取 header/query)
#
# 本测试目的:
# 1. 防止以后有人加回 `request: Request` 参数(回归保护)
# 2. 验证两个 endpoint 的参数签名(websocket 必须存在,request 不能有)
# 3. 验证 H5 WS endpoint 缺失 token 时返回 close code 4001(WS-01)
# 4. 验证 H5 WS endpoint token 不匹配 employee_id 时返回 close code 4001
# =============================================================================
import inspect
import uuid
import pytest
import pytest_asyncio
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect
from app.api import ws as ws_module
from app.api.ws import h5_websocket_endpoint, websocket_endpoint
from app.main import create_app
# =============================================================================
# 签名回归测试
# =============================================================================
class TestWebSocketEndpointSignature:
"""WebSocket endpoint 参数签名回归保护。
历史 bug: 早期版本有 `request: Request` 参数导致 FastAPI 启动失败。
修复方案: 移除该参数,改用 websocket.headers/query_params 读取。
"""
def test_websocket_endpoint_has_no_request_param(self):
"""坐席端 endpoint 不能有 `request` 参数(防 missing argument 回归)。"""
sig = inspect.signature(websocket_endpoint)
assert "request" not in sig.parameters, (
"websocket_endpoint 不应有 request 参数,FastAPI WebSocket 路由只支持 "
"websocket + 路径参数。回归会导致 'missing argument request' 500 错误!"
)
def test_h5_websocket_endpoint_has_no_request_param(self):
"""H5 端 endpoint 不能有 `request` 参数(防 missing argument 回归)。"""
sig = inspect.signature(h5_websocket_endpoint)
assert "request" not in sig.parameters, (
"h5_websocket_endpoint 不应有 request 参数!回归会导致 'missing argument request' 500 错误!"
)
def test_websocket_endpoint_first_param_is_websocket(self):
"""坐席端 endpoint 第一个参数必须是 WebSocket 类型。"""
sig = inspect.signature(websocket_endpoint)
params = list(sig.parameters.values())
assert params[0].annotation is WebSocket, (
f"坐席端第一个参数必须是 WebSocket,实际是 {params[0].annotation}"
)
def test_h5_websocket_endpoint_first_param_is_websocket(self):
"""H5 端 endpoint 第一个参数必须是 WebSocket 类型。"""
sig = inspect.signature(h5_websocket_endpoint)
params = list(sig.parameters.values())
assert params[0].annotation is WebSocket, (
f"H5 端第一个参数必须是 WebSocket,实际是 {params[0].annotation}"
)
def test_ws_router_is_registered_in_app(self):
"""主应用必须注册 ws router(否则 /ws 路径 404)。"""
app = create_app()
ws_routes = [r for r in app.routes if getattr(r, "path", "").startswith("/ws")]
assert any("/ws/{agent_id}" in getattr(r, "path", "") for r in ws_routes), (
"坐席 WS 路由 /ws/{agent_id} 未注册"
)
assert any("/ws/h5/{employee_id}" in getattr(r, "path", "") for r in ws_routes), (
"H5 WS 路由 /ws/h5/{employee_id} 未注册"
)
# =============================================================================
# 运行时测试 — 验证 WS 鉴权逻辑
# =============================================================================
@pytest_asyncio.fixture
async def mock_redis_with_employee(mock_redis):
"""把 employee_id 注入 mock Redis,模拟已登录状态。"""
employee_id = f"emp_{uuid.uuid4().hex[:8]}"
token = f"tok_{uuid.uuid4().hex[:16]}"
await mock_redis.setex(f"employee:token:{token}", 86400, employee_id)
return employee_id, token
class TestH5WebSocketRuntime:
"""H5 WebSocket 运行时测试 — 验证 auth 错误码。
不依赖 create_app()(避免触发 PG 连接),直接用 ws.py 的 router 构造
独立 FastAPI 实例。这样既验证 endpoint 行为,又不需要任何外部服务。
"""
def _build_ws_only_app(self):
"""构造只含 ws router 的 FastAPI 实例(无 DB/Redis 依赖)。"""
from fastapi import FastAPI
from app.api.ws import router as ws_router
app = FastAPI()
app.include_router(ws_router)
return app
@pytest.mark.asyncio
async def test_h5_ws_missing_token_closes_with_4001(self):
"""缺 token 时,server 应 close(code=4001) — WS-01 安全要求。"""
from app.services.cache_service import cache_service
from starlette.testclient import TestClient
app = self._build_ws_only_app()
with pytest.MonkeyPatch.context() as mp:
async def fake_get(key):
return None # 模拟 token 不存在
mp.setattr(cache_service, "get", fake_get)
with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect) as exc_info:
with client.websocket_connect("/ws/h5/emp_test") as ws:
# 不带任何 token,期望 close code 4001
ws.receive_text()
# close code 应当是 4001(自定义未授权)
assert exc_info.value.code == 4001, (
f"缺 token 应关闭 4001,实际 {exc_info.value.code}"
)
@pytest.mark.asyncio
async def test_h5_ws_token_employee_mismatch_closes_with_4001(self):
"""token 对应的 employee_id 与 URL 不一致时,close 4001。"""
from app.services.cache_service import cache_service
from starlette.testclient import TestClient
app = self._build_ws_only_app()
with pytest.MonkeyPatch.context() as mp:
async def fake_get(key):
return b"emp_real" # token 对应 emp_real
mp.setattr(cache_service, "get", fake_get)
with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect) as exc_info:
with client.websocket_connect(
"/ws/h5/emp_impostor?token=fake_token"
) as ws:
ws.receive_text()
assert exc_info.value.code == 4001, (
f"token-employee 不匹配应关闭 4001,实际 {exc_info.value.code}"
)
class TestAgentWebSocketRuntime:
"""坐席 WebSocket 运行时测试 — 验证 auth 错误码。"""
def _build_ws_only_app(self):
from fastapi import FastAPI
from app.api.ws import router as ws_router
app = FastAPI()
app.include_router(ws_router)
return app
@pytest.mark.asyncio
async def test_agent_ws_missing_token_closes_with_4001(self):
"""坐席端缺 token 关闭 4001。"""
from app.services.cache_service import cache_service
from starlette.testclient import TestClient
app = self._build_ws_only_app()
with pytest.MonkeyPatch.context() as mp:
async def fake_get(key):
return None
mp.setattr(cache_service, "get", fake_get)
with TestClient(app) as client:
with pytest.raises(WebSocketDisconnect) as exc_info:
with client.websocket_connect("/ws/agent_test") as ws:
ws.receive_text()
assert exc_info.value.code == 4001