321 lines
14 KiB
Python
321 lines
14 KiB
Python
# =============================================================================
|
||
# 企微IT智能服务台 — WebSocket 端点
|
||
# =============================================================================
|
||
# 说明:提供 WebSocket 端点,供坐席前端和H5用户端建立长连接,实现实时推送。
|
||
# 核心功能:
|
||
# 1. 接受坐席的 WebSocket 连接请求(含 token 认证)— /ws/{agent_id}
|
||
# 2. 接受H5员工的 WebSocket 连接请求(含 token 认证)— /ws/h5/{employee_id}
|
||
# 3. 维持连接,监听客户端消息(主要是心跳 ping)
|
||
# 4. 连接断开时自动清理注册信息
|
||
# 安全(WS-01):
|
||
# 握手时从 query param 取 token → 查 Redis 验证 → 不通过则 close(code=4001)
|
||
# 防止未授权用户冒充坐席/员工建立 WS 连接
|
||
#
|
||
# 端点路径:
|
||
# - 坐席端:/ws/{agent_id}?token=xxx
|
||
# - H5员工端:/ws/h5/{employee_id}?token=xxx
|
||
# 为什么不挂 /api 前缀:WebSocket 不是 REST API,不走 Vite 的 /api 代理配置
|
||
# =============================================================================
|
||
|
||
import logging
|
||
|
||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||
from starlette.requests import Request
|
||
|
||
from app.services.ws_manager import manager as ws_manager
|
||
from app.services.cache_service import cache_service
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# WebSocket 路由器(不挂 /api 前缀,直接注册在应用根路径)
|
||
router = APIRouter()
|
||
|
||
# 认证失败时的 WebSocket 关闭码
|
||
# 4001 = 自定义码,表示"未授权"(4000+ 为应用自定义范围)
|
||
WS_CLOSE_UNAUTHORIZED = 4001
|
||
|
||
|
||
@router.websocket("/ws/{agent_id}")
|
||
async def websocket_endpoint(
|
||
websocket: WebSocket,
|
||
agent_id: str,
|
||
request: Request,
|
||
) -> None:
|
||
"""坐席 WebSocket 端点主循环(含 WS-01 token 认证)。
|
||
|
||
做什么:
|
||
1. 从 Authorization header 获取 token(优先)或 query param(兼容)
|
||
2. 验证 token 有效性(查 Redis)
|
||
3. 验证 token 与 agent_id 一致性(防冒充)
|
||
4. 认证通过后接受连接,注册到 ConnectionManager
|
||
5. 进入消息接收循环,处理客户端发送的消息
|
||
6. 连接断开时清理注册信息
|
||
|
||
为什么需要 token 认证(WS-01):
|
||
- 之前 /ws/{agent_id} 无任何认证,任何人知道 URL 即可冒充任意坐席
|
||
- 攻击者可监听所有消息、发送伪造消息,是 P0 级安全漏洞
|
||
- 修复后,必须提供与 agent_id 匹配的有效 token 才能建立连接
|
||
|
||
安全改进(P0-#4):
|
||
- 优先从 Authorization: Bearer {token} header 获取 token
|
||
- 兼容从 ?token= URL 参数获取(向后兼容)
|
||
- 不再将 token 暴露在 URL 中,避免 access_log 泄露
|
||
|
||
Args:
|
||
websocket: FastAPI WebSocket 对象(框架自动注入)
|
||
agent_id: 坐席ID(从 URL 路径参数获取)
|
||
request: Starlette Request(用于获取 header)
|
||
"""
|
||
# ======================================================================
|
||
# WS-01: Token 认证(从 subprotocol / header / query 获取)
|
||
# ======================================================================
|
||
|
||
# 步骤1: 优先从 Sec-WebSocket-Protocol (subprotocol) 获取 token,其次从 Authorization header,最后从 query(向后兼容)
|
||
# 格式: Sec-WebSocket-Protocol: bearer.{token}
|
||
# 说明: 浏览器原生 WebSocket API 不支持 headers 参数,但支持 subprotocols (第2参数数组)
|
||
# 前端用 new WebSocket(url, ["bearer.{token}"]) 传递,服务端从 sec-websocket-protocol 头读取
|
||
subprotocol = request.headers.get("sec-websocket-protocol", "")
|
||
if subprotocol.startswith("bearer."):
|
||
token = subprotocol[7:] # 去掉 "bearer." 前缀
|
||
else:
|
||
# 其次从 Authorization header 获取
|
||
auth_header = request.headers.get("Authorization", "")
|
||
if auth_header.startswith("Bearer "):
|
||
token = auth_header[7:] # 去掉 "Bearer " 前缀
|
||
else:
|
||
# 向后兼容:从 query param 获取(即将废弃)
|
||
token = request.query_params.get("token", "")
|
||
|
||
# 步骤2: 检查 token 是否为空
|
||
if not token:
|
||
# 先 accept 再 close,否则客户端收不到关闭帧
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Missing token")
|
||
logger.warning(f"WebSocket 拒绝连接: agent_id={agent_id}, 原因=缺少token")
|
||
return
|
||
|
||
# 步骤3: 从 Redis 查询 token 对应的坐席信息
|
||
# Redis 中存储格式: agent:token:{token} -> agent_user_id
|
||
# (与坐席登录 API /api/agents/login 存储格式一致)
|
||
try:
|
||
stored_agent_id = await cache_service.get(f"agent:token:{token}")
|
||
except Exception as e:
|
||
# Redis 不可用时必须拒绝连接:token 验证依赖 Redis,无法验证身份
|
||
# 如果降级放行,攻击者可在 Redis 故障时用任意 agent_id 冒充坐席
|
||
logger.error(f"Redis 查询失败,拒绝 WS 连接: agent_id={agent_id}, error={e}")
|
||
await websocket.accept()
|
||
await websocket.close(
|
||
code=WS_CLOSE_UNAUTHORIZED,
|
||
reason="Authentication service unavailable"
|
||
)
|
||
return
|
||
|
||
# 步骤4: 验证 token 与 agent_id 一致性
|
||
if not stored_agent_id:
|
||
# token 不存在(已过期或伪造)
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Invalid or expired token")
|
||
logger.warning(f"WebSocket 拒绝连接: agent_id={agent_id}, 原因=token无效或已过期")
|
||
return
|
||
|
||
if stored_agent_id != agent_id:
|
||
# token 对应的坐席与请求的 agent_id 不匹配(冒充)
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Token-agent mismatch")
|
||
logger.warning(
|
||
f"WebSocket 拒绝连接: agent_id={agent_id}, "
|
||
f"原因=token对应坐席{stored_agent_id}与请求不匹配"
|
||
)
|
||
return
|
||
|
||
# ======================================================================
|
||
# 认证通过,建立连接
|
||
# ======================================================================
|
||
|
||
# 注册连接(内部会调用 websocket.accept())
|
||
await ws_manager.connect(agent_id, websocket)
|
||
logger.info(f"坐席 WebSocket 连接已认证: agent_id={agent_id}")
|
||
|
||
try:
|
||
# 消息接收循环
|
||
# 保持连接打开,监听客户端发来的消息
|
||
# 即使客户端不发消息,这个循环也必须保持,否则连接会关闭
|
||
while True:
|
||
# 等待接收客户端消息(阻塞等待)
|
||
data = await websocket.receive_json()
|
||
|
||
# 处理心跳 ping
|
||
# 前端每 30 秒发送一次 ping,后端回复 pong
|
||
# 作用:检测连接是否存活,防止中间代理(如 Nginx)因超时断开连接
|
||
if data.get("type") == "ping":
|
||
await websocket.send_json({"type": "pong"})
|
||
logger.debug(f"WebSocket 心跳: agent_id={agent_id}")
|
||
|
||
# 处理输入指示器 typing 事件
|
||
# 前端在用户输入时发送 typing 事件,后端广播给同一会话的其他参与者
|
||
elif data.get("type") == "typing":
|
||
conversation_id = data.get("conversation_id")
|
||
sender_name = data.get("sender_name", agent_id)
|
||
if conversation_id:
|
||
# 广播给所有坐席(包含 sender_type 和 sender_id,
|
||
# 前端可据此过滤掉自己的 typing 事件)
|
||
await ws_manager.broadcast({
|
||
"type": "typing",
|
||
"data": {
|
||
"conversation_id": conversation_id,
|
||
"sender_id": agent_id,
|
||
"sender_name": sender_name,
|
||
"sender_type": "agent",
|
||
}
|
||
})
|
||
|
||
else:
|
||
# 未来可扩展处理其他类型的客户端消息
|
||
logger.debug(
|
||
f"WebSocket 收到未知消息: agent_id={agent_id}, "
|
||
f"type={data.get('type', 'unknown')}"
|
||
)
|
||
|
||
except WebSocketDisconnect:
|
||
# 客户端主动断开连接(正常行为)
|
||
# 清理 ConnectionManager 中的注册信息
|
||
ws_manager.disconnect(agent_id)
|
||
logger.info(f"坐席断开 WebSocket 连接: agent_id={agent_id}")
|
||
|
||
except Exception as e:
|
||
# 其他异常(如网络错误、JSON 解析错误等)
|
||
# 确保注册信息被清理
|
||
ws_manager.disconnect(agent_id)
|
||
logger.warning(f"WebSocket 异常断开: agent_id={agent_id}, error={e}")
|
||
|
||
|
||
# ==========================================================================
|
||
# H5员工 WebSocket 端点
|
||
# ==========================================================================
|
||
|
||
@router.websocket("/ws/h5/{employee_id}")
|
||
async def h5_websocket_endpoint(
|
||
websocket: WebSocket,
|
||
employee_id: str,
|
||
request: Request,
|
||
) -> None:
|
||
"""H5员工 WebSocket 端点主循环(含 token 认证)。
|
||
|
||
做什么:
|
||
1. 从 Authorization header 获取 token(优先从)或 query param(兼容)
|
||
2. 验证 employee token 有效性(查 Redis)
|
||
3. 验证 token 与 employee_id 一致性(防冒充)
|
||
4. 认证通过后接受连接,注册到 ConnectionManager 的员工连接表
|
||
5. 进入消息接收循环,处理心跳 ping
|
||
6. 连接断开时清理注册信息
|
||
|
||
为什么需要 H5 WS 连接:
|
||
- H5员工需要实时接收参与者变更事件(新参与者加入、有人退出等)
|
||
- 当前仅通过 3 秒轮询获取更新,实时性不足
|
||
- WS 推送 + 轮询降级,双通道保证消息可达
|
||
|
||
安全改进(P0-#4):
|
||
- 优先从 Authorization: Bearer {token} header 获取 token
|
||
- 兼容从 ?token= URL 参数获取(向后兼容)
|
||
|
||
认证机制(与坐席端一致):
|
||
- Redis 中存储格式: employee:token:{token} -> employee_id
|
||
- (与H5登录 API /api/h5/mock-login 存储格式一致)
|
||
- token 缺失、无效、过期、与 employee_id 不匹配均拒绝连接
|
||
|
||
Args:
|
||
websocket: FastAPI WebSocket 对象(框架自动注入)
|
||
employee_id: 员工企微 UserID(从 URL 路径参数获取)
|
||
request: Starlette Request(用于获取 header)
|
||
"""
|
||
# ======================================================================
|
||
# Token 认证(从 subprotocol / header / query 获取)
|
||
# ======================================================================
|
||
|
||
# 步骤1: 优先从 Sec-WebSocket-Protocol (subprotocol) 获取 token,其次从 Authorization header,最后从 query(向后兼容)
|
||
# 格式: Sec-WebSocket-Protocol: bearer.{token}
|
||
subprotocol = request.headers.get("sec-websocket-protocol", "")
|
||
if subprotocol.startswith("bearer."):
|
||
token = subprotocol[7:] # 去掉 "bearer." 前缀
|
||
else:
|
||
# 其次从 Authorization header 获取
|
||
auth_header = request.headers.get("Authorization", "")
|
||
if auth_header.startswith("Bearer "):
|
||
token = auth_header[7:] # 去掉 "Bearer " 前缀
|
||
else:
|
||
# 向后兼容:从 query param 获取(即将废弃)
|
||
token = request.query_params.get("token", "")
|
||
|
||
# 步骤2: 检查 token 是否为空
|
||
if not token:
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Missing token")
|
||
logger.warning(f"H5 WebSocket 拒绝连接: employee_id={employee_id}, 原因=缺少token")
|
||
return
|
||
|
||
# 步骤3: 从 Redis 查询 token 对应的员工信息
|
||
# Redis 中存储格式: employee:token:{token} -> employee_id
|
||
# (与H5登录 API /api/h5/mock-login 存储格式一致)
|
||
try:
|
||
stored_employee_id = await cache_service.get(f"employee:token:{token}")
|
||
except Exception as e:
|
||
# Redis 不可用时必须拒绝连接(与坐席端一致的安全策略)
|
||
logger.error(f"Redis 查询失败,拒绝 H5 WS 连接: employee_id={employee_id}, error={e}")
|
||
await websocket.accept()
|
||
await websocket.close(
|
||
code=WS_CLOSE_UNAUTHORIZED,
|
||
reason="Authentication service unavailable"
|
||
)
|
||
return
|
||
|
||
# 步骤4: 验证 token 与 employee_id 一致性
|
||
if not stored_employee_id:
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Invalid or expired token")
|
||
logger.warning(f"H5 WebSocket 拒绝连接: employee_id={employee_id}, 原因=token无效或已过期")
|
||
return
|
||
|
||
if stored_employee_id != employee_id:
|
||
await websocket.accept()
|
||
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Token-employee mismatch")
|
||
logger.warning(
|
||
f"H5 WebSocket 拒绝连接: employee_id={employee_id}, "
|
||
f"原因=token对应员工{stored_employee_id}与请求不匹配"
|
||
)
|
||
return
|
||
|
||
# ======================================================================
|
||
# 认证通过,建立连接
|
||
# ======================================================================
|
||
|
||
# 注册员工连接(内部会调用 websocket.accept())
|
||
await ws_manager.connect_employee(employee_id, websocket)
|
||
logger.info(f"H5员工 WebSocket 连接已认证: employee_id={employee_id}")
|
||
|
||
try:
|
||
# 消息接收循环
|
||
# H5员工端目前只发送心跳 ping,不需要发送 typing 等事件
|
||
while True:
|
||
data = await websocket.receive_json()
|
||
|
||
# 处理心跳 ping
|
||
if data.get("type") == "ping":
|
||
await websocket.send_json({"type": "pong"})
|
||
logger.debug(f"H5 WebSocket 心跳: employee_id={employee_id}")
|
||
|
||
else:
|
||
logger.debug(
|
||
f"H5 WebSocket 收到未知消息: employee_id={employee_id}, "
|
||
f"type={data.get('type', 'unknown')}"
|
||
)
|
||
|
||
except WebSocketDisconnect:
|
||
# 客户端主动断开连接
|
||
ws_manager.disconnect_employee(employee_id)
|
||
logger.info(f"H5员工断开 WebSocket 连接: employee_id={employee_id}")
|
||
|
||
except Exception as e:
|
||
# 其他异常
|
||
ws_manager.disconnect_employee(employee_id)
|
||
logger.warning(f"H5 WebSocket 异常断开: employee_id={employee_id}, error={e}")
|