Files

321 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — WebSocket 端点
# =============================================================================
# 说明:提供 WebSocket 端点,供坐席前端和H5用户端建立长连接,实现实时推送。
# 核心功能:
# 1. 接受坐席的 WebSocket 连接请求(含 token 认证)— /ws/{agent_id}
# 2. 接受H5员工的 WebSocket 连接请求(含 token 认证)— /ws/h5/{employee_id}
# 3. 维持连接,监听客户端消息(主要是心跳 ping)
# 4. 连接断开时自动清理注册信息
# 安全(WS-01):
# 握手时从 query param 取 token → 查 Redis 验证 → 不通过则 close(code=4001)
# 防止未授权用户冒充坐席/员工建立 WS 连接
#
# 端点路径:
# - 坐席端:/ws/{agent_id}?token=xxx
# - H5员工端:/ws/h5/{employee_id}?token=xxx
# 为什么不挂 /api 前缀:WebSocket 不是 REST API,不走 Vite 的 /api 代理配置
# =============================================================================
import logging
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from starlette.requests import Request
from app.services.ws_manager import manager as ws_manager
from app.services.cache_service import cache_service
logger = logging.getLogger(__name__)
# WebSocket 路由器(不挂 /api 前缀,直接注册在应用根路径)
router = APIRouter()
# 认证失败时的 WebSocket 关闭码
# 4001 = 自定义码,表示"未授权"(4000+ 为应用自定义范围)
WS_CLOSE_UNAUTHORIZED = 4001
@router.websocket("/ws/{agent_id}")
async def websocket_endpoint(
websocket: WebSocket,
agent_id: str,
request: Request,
) -> None:
"""坐席 WebSocket 端点主循环(含 WS-01 token 认证)。
做什么:
1. 从 Authorization header 获取 token(优先)或 query param(兼容)
2. 验证 token 有效性(查 Redis
3. 验证 token 与 agent_id 一致性(防冒充)
4. 认证通过后接受连接,注册到 ConnectionManager
5. 进入消息接收循环,处理客户端发送的消息
6. 连接断开时清理注册信息
为什么需要 token 认证(WS-01):
- 之前 /ws/{agent_id} 无任何认证,任何人知道 URL 即可冒充任意坐席
- 攻击者可监听所有消息、发送伪造消息,是 P0 级安全漏洞
- 修复后,必须提供与 agent_id 匹配的有效 token 才能建立连接
安全改进(P0-#4):
- 优先从 Authorization: Bearer {token} header 获取 token
- 兼容从 ?token= URL 参数获取(向后兼容)
- 不再将 token 暴露在 URL 中,避免 access_log 泄露
Args:
websocket: FastAPI WebSocket 对象(框架自动注入)
agent_id: 坐席ID(从 URL 路径参数获取)
request: Starlette Request(用于获取 header
"""
# ======================================================================
# WS-01: Token 认证(从 subprotocol / header / query 获取)
# ======================================================================
# 步骤1: 优先从 Sec-WebSocket-Protocol (subprotocol) 获取 token,其次从 Authorization header,最后从 query(向后兼容)
# 格式: Sec-WebSocket-Protocol: bearer.{token}
# 说明: 浏览器原生 WebSocket API 不支持 headers 参数,但支持 subprotocols (第2参数数组)
# 前端用 new WebSocket(url, ["bearer.{token}"]) 传递,服务端从 sec-websocket-protocol 头读取
subprotocol = request.headers.get("sec-websocket-protocol", "")
if subprotocol.startswith("bearer."):
token = subprotocol[7:] # 去掉 "bearer." 前缀
else:
# 其次从 Authorization header 获取
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:] # 去掉 "Bearer " 前缀
else:
# 向后兼容:从 query param 获取(即将废弃)
token = request.query_params.get("token", "")
# 步骤2: 检查 token 是否为空
if not token:
# 先 accept 再 close,否则客户端收不到关闭帧
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Missing token")
logger.warning(f"WebSocket 拒绝连接: agent_id={agent_id}, 原因=缺少token")
return
# 步骤3: 从 Redis 查询 token 对应的坐席信息
# Redis 中存储格式: agent:token:{token} -> agent_user_id
# (与坐席登录 API /api/agents/login 存储格式一致)
try:
stored_agent_id = await cache_service.get(f"agent:token:{token}")
except Exception as e:
# Redis 不可用时必须拒绝连接:token 验证依赖 Redis,无法验证身份
# 如果降级放行,攻击者可在 Redis 故障时用任意 agent_id 冒充坐席
logger.error(f"Redis 查询失败,拒绝 WS 连接: agent_id={agent_id}, error={e}")
await websocket.accept()
await websocket.close(
code=WS_CLOSE_UNAUTHORIZED,
reason="Authentication service unavailable"
)
return
# 步骤4: 验证 token 与 agent_id 一致性
if not stored_agent_id:
# token 不存在(已过期或伪造)
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Invalid or expired token")
logger.warning(f"WebSocket 拒绝连接: agent_id={agent_id}, 原因=token无效或已过期")
return
if stored_agent_id != agent_id:
# token 对应的坐席与请求的 agent_id 不匹配(冒充)
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Token-agent mismatch")
logger.warning(
f"WebSocket 拒绝连接: agent_id={agent_id}, "
f"原因=token对应坐席{stored_agent_id}与请求不匹配"
)
return
# ======================================================================
# 认证通过,建立连接
# ======================================================================
# 注册连接(内部会调用 websocket.accept()
await ws_manager.connect(agent_id, websocket)
logger.info(f"坐席 WebSocket 连接已认证: agent_id={agent_id}")
try:
# 消息接收循环
# 保持连接打开,监听客户端发来的消息
# 即使客户端不发消息,这个循环也必须保持,否则连接会关闭
while True:
# 等待接收客户端消息(阻塞等待)
data = await websocket.receive_json()
# 处理心跳 ping
# 前端每 30 秒发送一次 ping,后端回复 pong
# 作用:检测连接是否存活,防止中间代理(如 Nginx)因超时断开连接
if data.get("type") == "ping":
await websocket.send_json({"type": "pong"})
logger.debug(f"WebSocket 心跳: agent_id={agent_id}")
# 处理输入指示器 typing 事件
# 前端在用户输入时发送 typing 事件,后端广播给同一会话的其他参与者
elif data.get("type") == "typing":
conversation_id = data.get("conversation_id")
sender_name = data.get("sender_name", agent_id)
if conversation_id:
# 广播给所有坐席(包含 sender_type 和 sender_id
# 前端可据此过滤掉自己的 typing 事件)
await ws_manager.broadcast({
"type": "typing",
"data": {
"conversation_id": conversation_id,
"sender_id": agent_id,
"sender_name": sender_name,
"sender_type": "agent",
}
})
else:
# 未来可扩展处理其他类型的客户端消息
logger.debug(
f"WebSocket 收到未知消息: agent_id={agent_id}, "
f"type={data.get('type', 'unknown')}"
)
except WebSocketDisconnect:
# 客户端主动断开连接(正常行为)
# 清理 ConnectionManager 中的注册信息
ws_manager.disconnect(agent_id)
logger.info(f"坐席断开 WebSocket 连接: agent_id={agent_id}")
except Exception as e:
# 其他异常(如网络错误、JSON 解析错误等)
# 确保注册信息被清理
ws_manager.disconnect(agent_id)
logger.warning(f"WebSocket 异常断开: agent_id={agent_id}, error={e}")
# ==========================================================================
# H5员工 WebSocket 端点
# ==========================================================================
@router.websocket("/ws/h5/{employee_id}")
async def h5_websocket_endpoint(
websocket: WebSocket,
employee_id: str,
request: Request,
) -> None:
"""H5员工 WebSocket 端点主循环(含 token 认证)。
做什么:
1. 从 Authorization header 获取 token(优先从)或 query param(兼容)
2. 验证 employee token 有效性(查 Redis
3. 验证 token 与 employee_id 一致性(防冒充)
4. 认证通过后接受连接,注册到 ConnectionManager 的员工连接表
5. 进入消息接收循环,处理心跳 ping
6. 连接断开时清理注册信息
为什么需要 H5 WS 连接:
- H5员工需要实时接收参与者变更事件(新参与者加入、有人退出等)
- 当前仅通过 3 秒轮询获取更新,实时性不足
- WS 推送 + 轮询降级,双通道保证消息可达
安全改进(P0-#4):
- 优先从 Authorization: Bearer {token} header 获取 token
- 兼容从 ?token= URL 参数获取(向后兼容)
认证机制(与坐席端一致):
- Redis 中存储格式: employee:token:{token} -> employee_id
- (与H5登录 API /api/h5/mock-login 存储格式一致)
- token 缺失、无效、过期、与 employee_id 不匹配均拒绝连接
Args:
websocket: FastAPI WebSocket 对象(框架自动注入)
employee_id: 员工企微 UserID(从 URL 路径参数获取)
request: Starlette Request(用于获取 header
"""
# ======================================================================
# Token 认证(从 subprotocol / header / query 获取)
# ======================================================================
# 步骤1: 优先从 Sec-WebSocket-Protocol (subprotocol) 获取 token,其次从 Authorization header,最后从 query(向后兼容)
# 格式: Sec-WebSocket-Protocol: bearer.{token}
subprotocol = request.headers.get("sec-websocket-protocol", "")
if subprotocol.startswith("bearer."):
token = subprotocol[7:] # 去掉 "bearer." 前缀
else:
# 其次从 Authorization header 获取
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:] # 去掉 "Bearer " 前缀
else:
# 向后兼容:从 query param 获取(即将废弃)
token = request.query_params.get("token", "")
# 步骤2: 检查 token 是否为空
if not token:
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Missing token")
logger.warning(f"H5 WebSocket 拒绝连接: employee_id={employee_id}, 原因=缺少token")
return
# 步骤3: 从 Redis 查询 token 对应的员工信息
# Redis 中存储格式: employee:token:{token} -> employee_id
# (与H5登录 API /api/h5/mock-login 存储格式一致)
try:
stored_employee_id = await cache_service.get(f"employee:token:{token}")
except Exception as e:
# Redis 不可用时必须拒绝连接(与坐席端一致的安全策略)
logger.error(f"Redis 查询失败,拒绝 H5 WS 连接: employee_id={employee_id}, error={e}")
await websocket.accept()
await websocket.close(
code=WS_CLOSE_UNAUTHORIZED,
reason="Authentication service unavailable"
)
return
# 步骤4: 验证 token 与 employee_id 一致性
if not stored_employee_id:
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Invalid or expired token")
logger.warning(f"H5 WebSocket 拒绝连接: employee_id={employee_id}, 原因=token无效或已过期")
return
if stored_employee_id != employee_id:
await websocket.accept()
await websocket.close(code=WS_CLOSE_UNAUTHORIZED, reason="Token-employee mismatch")
logger.warning(
f"H5 WebSocket 拒绝连接: employee_id={employee_id}, "
f"原因=token对应员工{stored_employee_id}与请求不匹配"
)
return
# ======================================================================
# 认证通过,建立连接
# ======================================================================
# 注册员工连接(内部会调用 websocket.accept()
await ws_manager.connect_employee(employee_id, websocket)
logger.info(f"H5员工 WebSocket 连接已认证: employee_id={employee_id}")
try:
# 消息接收循环
# H5员工端目前只发送心跳 ping,不需要发送 typing 等事件
while True:
data = await websocket.receive_json()
# 处理心跳 ping
if data.get("type") == "ping":
await websocket.send_json({"type": "pong"})
logger.debug(f"H5 WebSocket 心跳: employee_id={employee_id}")
else:
logger.debug(
f"H5 WebSocket 收到未知消息: employee_id={employee_id}, "
f"type={data.get('type', 'unknown')}"
)
except WebSocketDisconnect:
# 客户端主动断开连接
ws_manager.disconnect_employee(employee_id)
logger.info(f"H5员工断开 WebSocket 连接: employee_id={employee_id}")
except Exception as e:
# 其他异常
ws_manager.disconnect_employee(employee_id)
logger.warning(f"H5 WebSocket 异常断开: employee_id={employee_id}, error={e}")