189 lines
8.1 KiB
Python
189 lines
8.1 KiB
Python
|
|
# =============================================================================
|
||
|
|
# 企微IT智能服务台 — WebSocket 端点签名 + 错误码回归测试
|
||
|
|
# =============================================================================
|
||
|
|
# 背景(2026-06-21 事故):
|
||
|
|
# h5_websocket_endpoint 早期版本(2026-06-15 前)曾带一个多余 `request: Request`
|
||
|
|
# 参数,导致 FastAPI 启动时抛 "missing argument 'request'" / 客户端 WS 握手
|
||
|
|
# 直接失败、500 错误。前端 WS 连接直接失败,后端日志报错。
|
||
|
|
#
|
||
|
|
# 修复(2026-06-15):
|
||
|
|
# 移除 `request: Request` 参数(部分 Starlette 版本注入 Request 失败,
|
||
|
|
# 改用 `websocket.headers` 和 `websocket.query_params` 读取 header/query)
|
||
|
|
#
|
||
|
|
# 本测试目的:
|
||
|
|
# 1. 防止以后有人加回 `request: Request` 参数(回归保护)
|
||
|
|
# 2. 验证两个 endpoint 的参数签名(websocket 必须存在,request 不能有)
|
||
|
|
# 3. 验证 H5 WS endpoint 缺失 token 时返回 close code 4001(WS-01)
|
||
|
|
# 4. 验证 H5 WS endpoint token 不匹配 employee_id 时返回 close code 4001
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
import inspect
|
||
|
|
import uuid
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import pytest_asyncio
|
||
|
|
from fastapi import WebSocket
|
||
|
|
from starlette.websockets import WebSocketDisconnect
|
||
|
|
|
||
|
|
from app.api import ws as ws_module
|
||
|
|
from app.api.ws import h5_websocket_endpoint, websocket_endpoint
|
||
|
|
from app.main import create_app
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 签名回归测试
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
class TestWebSocketEndpointSignature:
|
||
|
|
"""WebSocket endpoint 参数签名回归保护。
|
||
|
|
|
||
|
|
历史 bug: 早期版本有 `request: Request` 参数导致 FastAPI 启动失败。
|
||
|
|
修复方案: 移除该参数,改用 websocket.headers/query_params 读取。
|
||
|
|
"""
|
||
|
|
|
||
|
|
def test_websocket_endpoint_has_no_request_param(self):
|
||
|
|
"""坐席端 endpoint 不能有 `request` 参数(防 missing argument 回归)。"""
|
||
|
|
sig = inspect.signature(websocket_endpoint)
|
||
|
|
assert "request" not in sig.parameters, (
|
||
|
|
"websocket_endpoint 不应有 request 参数,FastAPI WebSocket 路由只支持 "
|
||
|
|
"websocket + 路径参数。回归会导致 'missing argument request' 500 错误!"
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_h5_websocket_endpoint_has_no_request_param(self):
|
||
|
|
"""H5 端 endpoint 不能有 `request` 参数(防 missing argument 回归)。"""
|
||
|
|
sig = inspect.signature(h5_websocket_endpoint)
|
||
|
|
assert "request" not in sig.parameters, (
|
||
|
|
"h5_websocket_endpoint 不应有 request 参数!回归会导致 'missing argument request' 500 错误!"
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_websocket_endpoint_first_param_is_websocket(self):
|
||
|
|
"""坐席端 endpoint 第一个参数必须是 WebSocket 类型。"""
|
||
|
|
sig = inspect.signature(websocket_endpoint)
|
||
|
|
params = list(sig.parameters.values())
|
||
|
|
assert params[0].annotation is WebSocket, (
|
||
|
|
f"坐席端第一个参数必须是 WebSocket,实际是 {params[0].annotation}"
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_h5_websocket_endpoint_first_param_is_websocket(self):
|
||
|
|
"""H5 端 endpoint 第一个参数必须是 WebSocket 类型。"""
|
||
|
|
sig = inspect.signature(h5_websocket_endpoint)
|
||
|
|
params = list(sig.parameters.values())
|
||
|
|
assert params[0].annotation is WebSocket, (
|
||
|
|
f"H5 端第一个参数必须是 WebSocket,实际是 {params[0].annotation}"
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_ws_router_is_registered_in_app(self):
|
||
|
|
"""主应用必须注册 ws router(否则 /ws 路径 404)。"""
|
||
|
|
app = create_app()
|
||
|
|
ws_routes = [r for r in app.routes if getattr(r, "path", "").startswith("/ws")]
|
||
|
|
assert any("/ws/{agent_id}" in getattr(r, "path", "") for r in ws_routes), (
|
||
|
|
"坐席 WS 路由 /ws/{agent_id} 未注册"
|
||
|
|
)
|
||
|
|
assert any("/ws/h5/{employee_id}" in getattr(r, "path", "") for r in ws_routes), (
|
||
|
|
"H5 WS 路由 /ws/h5/{employee_id} 未注册"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# =============================================================================
|
||
|
|
# 运行时测试 — 验证 WS 鉴权逻辑
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
|
||
|
|
@pytest_asyncio.fixture
|
||
|
|
async def mock_redis_with_employee(mock_redis):
|
||
|
|
"""把 employee_id 注入 mock Redis,模拟已登录状态。"""
|
||
|
|
employee_id = f"emp_{uuid.uuid4().hex[:8]}"
|
||
|
|
token = f"tok_{uuid.uuid4().hex[:16]}"
|
||
|
|
await mock_redis.setex(f"employee:token:{token}", 86400, employee_id)
|
||
|
|
return employee_id, token
|
||
|
|
|
||
|
|
|
||
|
|
class TestH5WebSocketRuntime:
|
||
|
|
"""H5 WebSocket 运行时测试 — 验证 auth 错误码。
|
||
|
|
|
||
|
|
不依赖 create_app()(避免触发 PG 连接),直接用 ws.py 的 router 构造
|
||
|
|
独立 FastAPI 实例。这样既验证 endpoint 行为,又不需要任何外部服务。
|
||
|
|
"""
|
||
|
|
|
||
|
|
def _build_ws_only_app(self):
|
||
|
|
"""构造只含 ws router 的 FastAPI 实例(无 DB/Redis 依赖)。"""
|
||
|
|
from fastapi import FastAPI
|
||
|
|
from app.api.ws import router as ws_router
|
||
|
|
app = FastAPI()
|
||
|
|
app.include_router(ws_router)
|
||
|
|
return app
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_h5_ws_missing_token_closes_with_4001(self):
|
||
|
|
"""缺 token 时,server 应 close(code=4001) — WS-01 安全要求。"""
|
||
|
|
from app.services.cache_service import cache_service
|
||
|
|
from starlette.testclient import TestClient
|
||
|
|
|
||
|
|
app = self._build_ws_only_app()
|
||
|
|
with pytest.MonkeyPatch.context() as mp:
|
||
|
|
async def fake_get(key):
|
||
|
|
return None # 模拟 token 不存在
|
||
|
|
mp.setattr(cache_service, "get", fake_get)
|
||
|
|
|
||
|
|
with TestClient(app) as client:
|
||
|
|
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||
|
|
with client.websocket_connect("/ws/h5/emp_test") as ws:
|
||
|
|
# 不带任何 token,期望 close code 4001
|
||
|
|
ws.receive_text()
|
||
|
|
# close code 应当是 4001(自定义未授权)
|
||
|
|
assert exc_info.value.code == 4001, (
|
||
|
|
f"缺 token 应关闭 4001,实际 {exc_info.value.code}"
|
||
|
|
)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_h5_ws_token_employee_mismatch_closes_with_4001(self):
|
||
|
|
"""token 对应的 employee_id 与 URL 不一致时,close 4001。"""
|
||
|
|
from app.services.cache_service import cache_service
|
||
|
|
from starlette.testclient import TestClient
|
||
|
|
|
||
|
|
app = self._build_ws_only_app()
|
||
|
|
with pytest.MonkeyPatch.context() as mp:
|
||
|
|
async def fake_get(key):
|
||
|
|
return b"emp_real" # token 对应 emp_real
|
||
|
|
mp.setattr(cache_service, "get", fake_get)
|
||
|
|
|
||
|
|
with TestClient(app) as client:
|
||
|
|
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||
|
|
with client.websocket_connect(
|
||
|
|
"/ws/h5/emp_impostor?token=fake_token"
|
||
|
|
) as ws:
|
||
|
|
ws.receive_text()
|
||
|
|
assert exc_info.value.code == 4001, (
|
||
|
|
f"token-employee 不匹配应关闭 4001,实际 {exc_info.value.code}"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestAgentWebSocketRuntime:
|
||
|
|
"""坐席 WebSocket 运行时测试 — 验证 auth 错误码。"""
|
||
|
|
|
||
|
|
def _build_ws_only_app(self):
|
||
|
|
from fastapi import FastAPI
|
||
|
|
from app.api.ws import router as ws_router
|
||
|
|
app = FastAPI()
|
||
|
|
app.include_router(ws_router)
|
||
|
|
return app
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_agent_ws_missing_token_closes_with_4001(self):
|
||
|
|
"""坐席端缺 token 关闭 4001。"""
|
||
|
|
from app.services.cache_service import cache_service
|
||
|
|
from starlette.testclient import TestClient
|
||
|
|
|
||
|
|
app = self._build_ws_only_app()
|
||
|
|
with pytest.MonkeyPatch.context() as mp:
|
||
|
|
async def fake_get(key):
|
||
|
|
return None
|
||
|
|
mp.setattr(cache_service, "get", fake_get)
|
||
|
|
|
||
|
|
with TestClient(app) as client:
|
||
|
|
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||
|
|
with client.websocket_connect("/ws/agent_test") as ws:
|
||
|
|
ws.receive_text()
|
||
|
|
assert exc_info.value.code == 4001
|