Files

556 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — 消息管理 API
# =============================================================================
# 说明:坐席端的消息管理接口,包括:
# 1. GET /api/conversations/{id}/messages — 获取会话消息列表(分页)
# 2. POST /api/conversations/{id}/messages — 坐席发送消息
# 3. GET /api/conversations/{id}/messages/poll — 坐席轮询新消息
# 4. POST /api/messages/{id}/recall — 撤回消息(2分钟内)
# 5. DELETE /api/messages/{id} — 删除消息
# 6. POST /api/conversations/{id}/mark-read — 标记已读
# 7. POST /api/messages/image — 上传图片
# 8. POST /api/messages/file — 上传文件
# 消息发送需同时:存数据库 + 调用企微API发送给员工
# =============================================================================
import logging
import os
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, File, Query, UploadFile
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.schemas.message import MessageCreate, MessageResponse
from app.api.agents import get_current_agent
from app.services.wecom_service import WecomService
from app.utils.response import AppException, ERR_CONVERSATION_NOT_FOUND, ERR_CONVERSATION_RESOLVED, success_response
logger = logging.getLogger(__name__)
# 创建路由器
router = APIRouter()
# 文件大小限制:10MB
MAX_FILE_SIZE = 10 * 1024 * 1024
# 可撤回时间窗口:2分钟
RECALLABLE_WINDOW_MINUTES = 2
# --------------------------------------------------------------------------
# GET /api/conversations/{id}/messages — 获取会话消息列表
# --------------------------------------------------------------------------
@router.get("/conversations/{conversation_id}/messages")
async def list_messages(
conversation_id: str,
limit: int = Query(50, ge=1, le=100, description="每页消息数量"),
before: Optional[str] = Query(None, description="加载此消息ID之前的消息(向上翻页)"),
db: AsyncSession = Depends(get_db),
):
"""获取会话消息列表(分页)。
支持向上加载历史消息(通过 before 参数指定消息ID)。
默认返回最新的 limit 条消息。
Args:
conversation_id: 会话ID
limit: 每页消息数量
before: 加载此消息ID之前的消息(向上翻页)
db: 数据库会话
Returns:
Dict: 统一响应格式,包含消息列表和是否还有更多消息
"""
# 校验会话存在(UUID 转为字符串,兼容 SQLite String(36) 列)
conv_id_str = str(conversation_id)
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
conv_result = await db.execute(conv_stmt)
conversation = conv_result.scalars().first()
if not conversation:
raise ERR_CONVERSATION_NOT_FOUND
# 构建查询
stmt = select(Message).where(
Message.conversation_id == conv_id_str
).order_by(Message.created_at.desc())
# 如果指定了 before,只加载该消息之前的消息
if before:
try:
before_uuid = str(UUID(before))
# 先获取 before 消息的创建时间
before_stmt = select(Message.created_at).where(Message.id == before_uuid)
before_result = await db.execute(before_stmt)
before_time = before_result.scalar_one_or_none()
if before_time:
stmt = stmt.where(Message.created_at < before_time)
except ValueError:
pass # before 参数格式错误,忽略
# 限制数量
stmt = stmt.limit(limit + 1) # 多查一条判断是否还有更多
result = await db.execute(stmt)
messages = list(result.scalars().all())
# 判断是否还有更多消息
has_more = len(messages) > limit
if has_more:
messages = messages[:limit] # 去掉多查的那一条
# 按时间正序排列(最早的在前)
messages.reverse()
# 标记消息为已读(坐席查看时自动标记)
for msg in messages:
if not msg.is_read and msg.sender_type == "employee":
msg.is_read = True
await db.flush()
# 转换为响应格式
items = [MessageResponse.model_validate(m).model_dump() for m in messages]
return success_response(
data={
"items": items,
"has_more": has_more,
}
)
# --------------------------------------------------------------------------
# POST /api/conversations/{id}/messages — 坐席发送消息
# --------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/messages")
async def send_message(
conversation_id: str,
body: MessageCreate,
db: AsyncSession = Depends(get_db),
):
"""坐席发送消息。
流程:
1. 校验会话存在且未结单
2. 将消息存入 messages 表
3. 调用企微 API 发送消息给员工
4. 更新会话的最后消息信息
Args:
conversation_id: 会话ID
body: 消息请求体(包含 content 和 msg_type
db: 数据库会话
Returns:
Dict: 统一响应格式,包含发送的消息对象
"""
# 1. 校验会话(UUID 转为字符串,兼容 SQLite String(36) 列)
conv_id_str = str(conversation_id)
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
conv_result = await db.execute(conv_stmt)
conversation = conv_result.scalars().first()
if not conversation:
raise ERR_CONVERSATION_NOT_FOUND
if conversation.status == "resolved":
raise ERR_CONVERSATION_RESOLVED
# 2. 创建消息记录
# 从会话的 assigned_agent_id 获取坐席信息
agent_id = conversation.assigned_agent_id or "unknown"
# 计算可撤回截止时间
recallable_until = datetime.now() + timedelta(minutes=RECALLABLE_WINDOW_MINUTES)
message = Message(
conversation_id=conv_id_str,
sender_type="agent",
sender_id=agent_id,
sender_name="", # 坐席姓名,后续从坐席信息补充
content=body.content,
msg_type=body.msg_type,
# M1 新增:文件上传相关字段
media_url=body.media_url,
file_name=body.file_name,
file_size=body.file_size,
# M1 新增:引用回复
reply_to_id=body.reply_to_id,
status="sending", # 初始状态为发送中
recallable_until=recallable_until,
is_read=True, # 坐席自己发的消息默认已读
)
db.add(message)
# 3. 更新会话最后消息信息
conversation.last_message_at = datetime.now()
conversation.last_message_summary = body.content[:256]
conversation.updated_at = datetime.now()
db.add(conversation)
await db.flush() # 刷新以获取消息 ID
# 4. 调用企微 API 发送消息给员工
# 注意:只有 text 类型消息才需要调用企微 API 推送给员工
# image/file 等非文本消息暂不通过企微推送(仅存储消息记录供坐席查看)
# 跳过 Redis 连可避免无谓的网络开销,减少截图发送超时
if body.msg_type == "text":
try:
import redis.asyncio as aioredis
from app.config import settings
redis_client = settings.create_redis_client()
wecom_service = WecomService(redis_client)
await wecom_service.send_text_message(
conversation.employee_id, body.content
)
await wecom_service.close()
await redis_client.close()
except Exception as e:
# 企微 API 调用失败不阻塞消息存储
logger.warning(f"企微消息发送失败(消息已存储): {e}")
# 5. 更新消息状态为已发送
message.status = "sent"
await db.flush()
# 转换为响应格式
response_data = MessageResponse.model_validate(message).model_dump()
return success_response(data=response_data)
# --------------------------------------------------------------------------
# GET /api/conversations/{id}/messages/poll — 坐席轮询新消息
# --------------------------------------------------------------------------
@router.get("/conversations/{conversation_id}/messages/poll")
async def poll_messages(
conversation_id: str,
after_message_id: Optional[str] = Query(None, description="返回此消息ID之后的新消息"),
db: AsyncSession = Depends(get_db),
):
"""坐席轮询新消息。
前端每 3-5 秒调用一次,获取上次轮询后的新消息。
Args:
conversation_id: 会话ID
after_message_id: 上次轮询的最后一消息ID(返回此之后的消息)
db: 数据库会话
Returns:
Dict: 统一响应格式,包含新消息列表
"""
# 构建查询(UUID 转为字符串,兼容 SQLite String(36) 列)
conv_id_str = str(conversation_id)
stmt = select(Message).where(
Message.conversation_id == conv_id_str
).order_by(Message.created_at.asc())
# 如果指定了 after_message_id,只返回该ID之后的消息
if after_message_id:
try:
# 获取 after_message 的创建时间
# 注意:确保用字符串比较,避免SQLAlchemy把参数转成UUID导致类型不匹配
after_stmt = select(Message.created_at).where(
Message.id == str(after_message_id)
)
after_result = await db.execute(after_stmt)
after_time = after_result.scalar_one_or_none()
if after_time:
stmt = stmt.where(Message.created_at > after_time)
except Exception:
pass # 参数格式错误或查询失败,忽略
result = await db.execute(stmt)
messages = list(result.scalars().all())
# 标记员工消息为已读
for msg in messages:
if not msg.is_read and msg.sender_type == "employee":
msg.is_read = True
await db.flush()
# 转换为响应格式
items = [MessageResponse.model_validate(m).model_dump() for m in messages]
return success_response(
data={
"items": items,
"has_more": False, # 轮询接口不需要分页
}
)
# --------------------------------------------------------------------------
# POST /api/messages/{id}/recall — 撤回消息(2分钟内)
# --------------------------------------------------------------------------
@router.post("/messages/{message_id}/recall")
async def recall_message(
message_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
):
"""撤回消息(2分钟内)。
仅可撤回2分钟内坐席自己发送的消息。
P0-2 安全修复(2026-06-14 评审):
此前完全无鉴权,任意 HTTP 客户端可调用此端点修改任意消息。
现在依赖 get_current_agent 校验登录态,再校验 message.sender_id
是否等于当前坐席的 user_id,防止越权撤回他人消息。
Args:
message_id: 消息ID
agent: 当前坐席(鉴权依赖注入)
db: 数据库会话
Returns:
Dict: 统一响应格式
"""
# 查询消息
stmt = select(Message).where(Message.id == str(message_id))
result = await db.execute(stmt)
message = result.scalars().first()
if not message:
raise AppException(code=404, message="消息不存在")
# 校验是否是坐席发送的消息
if message.sender_type != "agent":
raise AppException(code=403, message="只能撤回坐席发送的消息")
# P0-2 修复:校验是否是当前坐席自己发的
if message.sender_id != agent.user_id:
raise AppException(code=403, message="只能撤回自己的消息")
# 校验是否在可撤回时间窗口内
if message.recallable_until and datetime.now() > message.recallable_until:
raise AppException(code=403, message="消息已超过2分钟,无法撤回")
# 将消息内容置为空,表示已撤回
message.content = "[消息已撤回]"
message.status = "recalled"
await db.flush()
return success_response(message="消息撤回成功")
# --------------------------------------------------------------------------
# DELETE /api/messages/{id} — 删除消息
# --------------------------------------------------------------------------
@router.delete("/messages/{message_id}")
async def delete_message(
message_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
):
"""删除坐席自己发送的消息。
P0-3 安全修复(2026-06-14 评审):
此前完全无鉴权,任意 HTTP 客户端可删除任意消息。
现在依赖 get_current_agent 校验登录态,再校验消息是否属于当前坐席,
防止越权删除他人/会话历史。
Args:
message_id: 消息ID
agent: 当前坐席(鉴权依赖注入)
db: 数据库会话
Returns:
Dict: 统一响应格式
"""
# 查询消息
stmt = select(Message).where(Message.id == str(message_id))
result = await db.execute(stmt)
message = result.scalars().first()
if not message:
raise AppException(code=404, message="消息不存在")
# P0-3 修复:仅允许坐席删除自己发送的消息
if message.sender_type != "agent" or message.sender_id != agent.user_id:
raise AppException(code=403, message="只能删除自己发送的消息")
# 删除消息
await db.delete(message)
await db.flush()
return success_response(message="消息删除成功")
# --------------------------------------------------------------------------
# POST /api/conversations/{id}/mark-read — 标记已读
# --------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/mark-read")
async def mark_read(
conversation_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
):
"""标记会话中所有员工未读消息为已读。
P0-4 安全修复(2026-06-14 评审):
此前完全无鉴权,任意 HTTP 客户端可标记任意会话为已读,
会破坏"未读消息数"业务统计。
现在依赖 get_current_agent 校验登录态,再校验当前坐席是会话的
主责或协作坐席才允许标记,防止越权篡改未读状态。
P2-3 修复:原 `.where(Message.is_read == False)` 是 Python 表达式比较
永远为 False(不抛错但实际未过滤),SQLAlchemy 也会当成赋值表达式
处理;改为 `is_(False)` 走 SQL 否定。
Args:
conversation_id: 会话ID
agent: 当前坐席(鉴权依赖注入)
db: 数据库会话
Returns:
Dict: 统一响应格式
"""
conv_id_str = str(conversation_id)
# P0-4 修复:先校验当前坐席有权访问此会话
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
conv_result = await db.execute(conv_stmt)
conversation = conv_result.scalars().first()
if not conversation:
raise ERR_CONVERSATION_NOT_FOUND
is_assigned = conversation.assigned_agent_id == agent.user_id
is_collaborator = agent.user_id in (conversation.collaborating_agent_ids or [])
if not (is_assigned or is_collaborator):
raise AppException(code=403, message="您不是该会话的坐席,无权操作")
# P2-3 修复:使用 is_(False) 而非 == False
# 更新该会话的所有员工未读消息为已读
stmt = (
update(Message)
.where(Message.conversation_id == conv_id_str)
.where(Message.sender_type == "employee")
.where(Message.is_read.is_(False))
.values(is_read=True, status="read")
)
await db.execute(stmt)
await db.flush()
return success_response(message="标记已读成功")
# --------------------------------------------------------------------------
# POST /api/messages/image — 上传图片
# --------------------------------------------------------------------------
@router.post("/messages/image")
async def upload_image(
file: UploadFile = File(...),
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
):
"""上传图片文件。
文件大小限制:10MB
Args:
file: 图片文件
db: 数据库会话
Returns:
Dict: 统一响应格式,包含文件URL和元数据
"""
# 校验文件大小
file.file.seek(0, 2)
file_size = file.file.tell()
file.file.seek(0)
if file_size > MAX_FILE_SIZE:
raise AppException(code=400, message=f"文件大小超过10MB限制")
# 校验文件类型
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
content_type = file.content_type
if content_type not in allowed_types:
raise AppException(code=400, message="不支持的图片格式")
# 生成保存路径
import uuid as uuid_module
file_ext = os.path.splitext(file.filename)[1] if file.filename else ".jpg"
file_name = f"{uuid_module.uuid4()}{file_ext}"
upload_dir = os.path.join("media", "images")
os.makedirs(upload_dir, exist_ok=True)
file_path = os.path.join(upload_dir, file_name)
# 保存文件
content = await file.read()
with open(file_path, "wb") as f:
f.write(content)
# 返回文件URL
file_url = f"/media/images/{file_name}"
return success_response(
data={
"url": file_url,
"filename": file_name,
"file_size": file_size,
"content_type": content_type,
}
)
# --------------------------------------------------------------------------
# POST /api/messages/file — 上传文件
# --------------------------------------------------------------------------
@router.post("/messages/file")
async def upload_message_file(
file: UploadFile = File(...),
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
):
"""上传普通文件。
文件大小限制:10MB
Args:
file: 文件
db: 数据库会话
Returns:
Dict: 统一响应格式,包含文件URL和元数据
"""
# 校大小
file.file.seek(0, 2)
file_size = file.file.tell()
file.file.seek(0)
if file_size > MAX_FILE_SIZE:
raise AppException(code=400, message=f"文件大小超过10MB限制")
# 生成保存路径
import uuid as uuid_module
original_name = file.filename or "file"
file_ext = os.path.splitext(original_name)[1]
file_name = f"{uuid_module.uuid4()}{file_ext}"
upload_dir = os.path.join("media", "files")
os.makedirs(upload_dir, exist_ok=True)
file_path = os.path.join(upload_dir, file_name)
# 保存文件
content = await file.read()
with open(file_path, "wb") as f:
f.write(content)
# 返回文件URL
file_url = f"/media/files/{file_name}"
return success_response(
data={
"url": file_url,
"filename": original_name,
"file_size": file_size,
"content_type": file.content_type,
}
)