chore: initial baseline with P0-safety .gitignore
This commit is contained in:
@@ -0,0 +1,556 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 消息管理 API
|
||||
# =============================================================================
|
||||
# 说明:坐席端的消息管理接口,包括:
|
||||
# 1. GET /api/conversations/{id}/messages — 获取会话消息列表(分页)
|
||||
# 2. POST /api/conversations/{id}/messages — 坐席发送消息
|
||||
# 3. GET /api/conversations/{id}/messages/poll — 坐席轮询新消息
|
||||
# 4. POST /api/messages/{id}/recall — 撤回消息(2分钟内)
|
||||
# 5. DELETE /api/messages/{id} — 删除消息
|
||||
# 6. POST /api/conversations/{id}/mark-read — 标记已读
|
||||
# 7. POST /api/messages/image — 上传图片
|
||||
# 8. POST /api/messages/file — 上传文件
|
||||
# 消息发送需同时:存数据库 + 调用企微API发送给员工
|
||||
# =============================================================================
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Query, UploadFile
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.agent import Agent
|
||||
from app.models.conversation import Conversation
|
||||
from app.models.message import Message
|
||||
from app.schemas.message import MessageCreate, MessageResponse
|
||||
from app.api.agents import get_current_agent
|
||||
from app.services.wecom_service import WecomService
|
||||
from app.utils.response import AppException, ERR_CONVERSATION_NOT_FOUND, ERR_CONVERSATION_RESOLVED, success_response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
# 文件大小限制:10MB
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# 可撤回时间窗口:2分钟
|
||||
RECALLABLE_WINDOW_MINUTES = 2
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# GET /api/conversations/{id}/messages — 获取会话消息列表
|
||||
# --------------------------------------------------------------------------
|
||||
@router.get("/conversations/{conversation_id}/messages")
|
||||
async def list_messages(
|
||||
conversation_id: str,
|
||||
limit: int = Query(50, ge=1, le=100, description="每页消息数量"),
|
||||
before: Optional[str] = Query(None, description="加载此消息ID之前的消息(向上翻页)"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""获取会话消息列表(分页)。
|
||||
|
||||
支持向上加载历史消息(通过 before 参数指定消息ID)。
|
||||
默认返回最新的 limit 条消息。
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
limit: 每页消息数量
|
||||
before: 加载此消息ID之前的消息(向上翻页)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式,包含消息列表和是否还有更多消息
|
||||
"""
|
||||
# 校验会话存在(UUID 转为字符串,兼容 SQLite String(36) 列)
|
||||
conv_id_str = str(conversation_id)
|
||||
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
|
||||
conv_result = await db.execute(conv_stmt)
|
||||
conversation = conv_result.scalars().first()
|
||||
if not conversation:
|
||||
raise ERR_CONVERSATION_NOT_FOUND
|
||||
|
||||
# 构建查询
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv_id_str
|
||||
).order_by(Message.created_at.desc())
|
||||
|
||||
# 如果指定了 before,只加载该消息之前的消息
|
||||
if before:
|
||||
try:
|
||||
before_uuid = str(UUID(before))
|
||||
# 先获取 before 消息的创建时间
|
||||
before_stmt = select(Message.created_at).where(Message.id == before_uuid)
|
||||
before_result = await db.execute(before_stmt)
|
||||
before_time = before_result.scalar_one_or_none()
|
||||
if before_time:
|
||||
stmt = stmt.where(Message.created_at < before_time)
|
||||
except ValueError:
|
||||
pass # before 参数格式错误,忽略
|
||||
|
||||
# 限制数量
|
||||
stmt = stmt.limit(limit + 1) # 多查一条判断是否还有更多
|
||||
|
||||
result = await db.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
|
||||
# 判断是否还有更多消息
|
||||
has_more = len(messages) > limit
|
||||
if has_more:
|
||||
messages = messages[:limit] # 去掉多查的那一条
|
||||
|
||||
# 按时间正序排列(最早的在前)
|
||||
messages.reverse()
|
||||
|
||||
# 标记消息为已读(坐席查看时自动标记)
|
||||
for msg in messages:
|
||||
if not msg.is_read and msg.sender_type == "employee":
|
||||
msg.is_read = True
|
||||
await db.flush()
|
||||
|
||||
# 转换为响应格式
|
||||
items = [MessageResponse.model_validate(m).model_dump() for m in messages]
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"items": items,
|
||||
"has_more": has_more,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# POST /api/conversations/{id}/messages — 坐席发送消息
|
||||
# --------------------------------------------------------------------------
|
||||
@router.post("/conversations/{conversation_id}/messages")
|
||||
async def send_message(
|
||||
conversation_id: str,
|
||||
body: MessageCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""坐席发送消息。
|
||||
|
||||
流程:
|
||||
1. 校验会话存在且未结单
|
||||
2. 将消息存入 messages 表
|
||||
3. 调用企微 API 发送消息给员工
|
||||
4. 更新会话的最后消息信息
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
body: 消息请求体(包含 content 和 msg_type)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式,包含发送的消息对象
|
||||
"""
|
||||
# 1. 校验会话(UUID 转为字符串,兼容 SQLite String(36) 列)
|
||||
conv_id_str = str(conversation_id)
|
||||
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
|
||||
conv_result = await db.execute(conv_stmt)
|
||||
conversation = conv_result.scalars().first()
|
||||
|
||||
if not conversation:
|
||||
raise ERR_CONVERSATION_NOT_FOUND
|
||||
if conversation.status == "resolved":
|
||||
raise ERR_CONVERSATION_RESOLVED
|
||||
|
||||
# 2. 创建消息记录
|
||||
# 从会话的 assigned_agent_id 获取坐席信息
|
||||
agent_id = conversation.assigned_agent_id or "unknown"
|
||||
|
||||
# 计算可撤回截止时间
|
||||
recallable_until = datetime.now() + timedelta(minutes=RECALLABLE_WINDOW_MINUTES)
|
||||
|
||||
message = Message(
|
||||
conversation_id=conv_id_str,
|
||||
sender_type="agent",
|
||||
sender_id=agent_id,
|
||||
sender_name="", # 坐席姓名,后续从坐席信息补充
|
||||
content=body.content,
|
||||
msg_type=body.msg_type,
|
||||
# M1 新增:文件上传相关字段
|
||||
media_url=body.media_url,
|
||||
file_name=body.file_name,
|
||||
file_size=body.file_size,
|
||||
# M1 新增:引用回复
|
||||
reply_to_id=body.reply_to_id,
|
||||
status="sending", # 初始状态为发送中
|
||||
recallable_until=recallable_until,
|
||||
is_read=True, # 坐席自己发的消息默认已读
|
||||
)
|
||||
db.add(message)
|
||||
|
||||
# 3. 更新会话最后消息信息
|
||||
conversation.last_message_at = datetime.now()
|
||||
conversation.last_message_summary = body.content[:256]
|
||||
conversation.updated_at = datetime.now()
|
||||
db.add(conversation)
|
||||
|
||||
await db.flush() # 刷新以获取消息 ID
|
||||
|
||||
# 4. 调用企微 API 发送消息给员工
|
||||
# 注意:只有 text 类型消息才需要调用企微 API 推送给员工
|
||||
# image/file 等非文本消息暂不通过企微推送(仅存储消息记录供坐席查看)
|
||||
# 跳过 Redis 连��可避免无谓的网络开销,减少截图发送超时
|
||||
if body.msg_type == "text":
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from app.config import settings
|
||||
|
||||
redis_client = settings.create_redis_client()
|
||||
wecom_service = WecomService(redis_client)
|
||||
|
||||
await wecom_service.send_text_message(
|
||||
conversation.employee_id, body.content
|
||||
)
|
||||
|
||||
await wecom_service.close()
|
||||
await redis_client.close()
|
||||
|
||||
except Exception as e:
|
||||
# 企微 API 调用失败不阻塞消息存储
|
||||
logger.warning(f"企微消息发送失败(消息已存储): {e}")
|
||||
|
||||
# 5. 更新消息状态为已发送
|
||||
message.status = "sent"
|
||||
await db.flush()
|
||||
|
||||
# 转换为响应格式
|
||||
response_data = MessageResponse.model_validate(message).model_dump()
|
||||
return success_response(data=response_data)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# GET /api/conversations/{id}/messages/poll — 坐席轮询新消息
|
||||
# --------------------------------------------------------------------------
|
||||
@router.get("/conversations/{conversation_id}/messages/poll")
|
||||
async def poll_messages(
|
||||
conversation_id: str,
|
||||
after_message_id: Optional[str] = Query(None, description="返回此消息ID之后的新消息"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""坐席轮询新消息。
|
||||
|
||||
前端每 3-5 秒调用一次,获取上次轮询后的新消息。
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
after_message_id: 上次轮询的最后一消息ID(返回此之后的消息)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式,包含新消息列表
|
||||
"""
|
||||
# 构建查询(UUID 转为字符串,兼容 SQLite String(36) 列)
|
||||
conv_id_str = str(conversation_id)
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conv_id_str
|
||||
).order_by(Message.created_at.asc())
|
||||
|
||||
# 如果指定了 after_message_id,只返回该ID之后的消息
|
||||
if after_message_id:
|
||||
try:
|
||||
# 获取 after_message 的创建时间
|
||||
# 注意:确保用字符串比较,避免SQLAlchemy把参数转成UUID导致类型不匹配
|
||||
after_stmt = select(Message.created_at).where(
|
||||
Message.id == str(after_message_id)
|
||||
)
|
||||
after_result = await db.execute(after_stmt)
|
||||
after_time = after_result.scalar_one_or_none()
|
||||
if after_time:
|
||||
stmt = stmt.where(Message.created_at > after_time)
|
||||
except Exception:
|
||||
pass # 参数格式错误或查询失败,忽略
|
||||
|
||||
result = await db.execute(stmt)
|
||||
messages = list(result.scalars().all())
|
||||
|
||||
# 标记员工消息为已读
|
||||
for msg in messages:
|
||||
if not msg.is_read and msg.sender_type == "employee":
|
||||
msg.is_read = True
|
||||
await db.flush()
|
||||
|
||||
# 转换为响应格式
|
||||
items = [MessageResponse.model_validate(m).model_dump() for m in messages]
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
"items": items,
|
||||
"has_more": False, # 轮询接口不需要分页
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# POST /api/messages/{id}/recall — 撤回消息(2分钟内)
|
||||
# --------------------------------------------------------------------------
|
||||
@router.post("/messages/{message_id}/recall")
|
||||
async def recall_message(
|
||||
message_id: str,
|
||||
agent: Agent = Depends(get_current_agent),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""撤回消息(2分钟内)。
|
||||
|
||||
仅可撤回2分钟内坐席自己发送的消息。
|
||||
|
||||
P0-2 安全修复(2026-06-14 评审):
|
||||
此前完全无鉴权,任意 HTTP 客户端可调用此端点修改任意消息。
|
||||
现在依赖 get_current_agent 校验登录态,再校验 message.sender_id
|
||||
是否等于当前坐席的 user_id,防止越权撤回他人消息。
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
agent: 当前坐席(鉴权依赖注入)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式
|
||||
"""
|
||||
# 查询消息
|
||||
stmt = select(Message).where(Message.id == str(message_id))
|
||||
result = await db.execute(stmt)
|
||||
message = result.scalars().first()
|
||||
|
||||
if not message:
|
||||
raise AppException(code=404, message="消息不存在")
|
||||
|
||||
# 校验是否是坐席发送的消息
|
||||
if message.sender_type != "agent":
|
||||
raise AppException(code=403, message="只能撤回坐席发送的消息")
|
||||
|
||||
# P0-2 修复:校验是否是当前坐席自己发的
|
||||
if message.sender_id != agent.user_id:
|
||||
raise AppException(code=403, message="只能撤回自己的消息")
|
||||
|
||||
# 校验是否在可撤回时间窗口内
|
||||
if message.recallable_until and datetime.now() > message.recallable_until:
|
||||
raise AppException(code=403, message="消息已超过2分钟,无法撤回")
|
||||
|
||||
# 将消息内容置为空,表示已撤回
|
||||
message.content = "[消息已撤回]"
|
||||
message.status = "recalled"
|
||||
await db.flush()
|
||||
|
||||
return success_response(message="消息撤回成功")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# DELETE /api/messages/{id} — 删除消息
|
||||
# --------------------------------------------------------------------------
|
||||
@router.delete("/messages/{message_id}")
|
||||
async def delete_message(
|
||||
message_id: str,
|
||||
agent: Agent = Depends(get_current_agent),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""删除坐席自己发送的消息。
|
||||
|
||||
P0-3 安全修复(2026-06-14 评审):
|
||||
此前完全无鉴权,任意 HTTP 客户端可删除任意消息。
|
||||
现在依赖 get_current_agent 校验登录态,再校验消息是否属于当前坐席,
|
||||
防止越权删除他人/会话历史。
|
||||
|
||||
Args:
|
||||
message_id: 消息ID
|
||||
agent: 当前坐席(鉴权依赖注入)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式
|
||||
"""
|
||||
# 查询消息
|
||||
stmt = select(Message).where(Message.id == str(message_id))
|
||||
result = await db.execute(stmt)
|
||||
message = result.scalars().first()
|
||||
|
||||
if not message:
|
||||
raise AppException(code=404, message="消息不存在")
|
||||
|
||||
# P0-3 修复:仅允许坐席删除自己发送的消息
|
||||
if message.sender_type != "agent" or message.sender_id != agent.user_id:
|
||||
raise AppException(code=403, message="只能删除自己发送的消息")
|
||||
|
||||
# 删除消息
|
||||
await db.delete(message)
|
||||
await db.flush()
|
||||
|
||||
return success_response(message="消息删除成功")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# POST /api/conversations/{id}/mark-read — 标记已读
|
||||
# --------------------------------------------------------------------------
|
||||
@router.post("/conversations/{conversation_id}/mark-read")
|
||||
async def mark_read(
|
||||
conversation_id: str,
|
||||
agent: Agent = Depends(get_current_agent),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""标记会话中所有员工未读消息为已读。
|
||||
|
||||
P0-4 安全修复(2026-06-14 评审):
|
||||
此前完全无鉴权,任意 HTTP 客户端可标记任意会话为已读,
|
||||
会破坏"未读消息数"业务统计。
|
||||
现在依赖 get_current_agent 校验登录态,再校验当前坐席是会话的
|
||||
主责或协作坐席才允许标记,防止越权篡改未读状态。
|
||||
P2-3 修复:原 `.where(Message.is_read == False)` 是 Python 表达式比较
|
||||
永远为 False(不抛错但实际未过滤),SQLAlchemy 也会当成赋值表达式
|
||||
处理;改为 `is_(False)` 走 SQL 否定。
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
agent: 当前坐席(鉴权依赖注入)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式
|
||||
"""
|
||||
conv_id_str = str(conversation_id)
|
||||
|
||||
# P0-4 修复:先校验当前坐席有权访问此会话
|
||||
conv_stmt = select(Conversation).where(Conversation.id == conv_id_str)
|
||||
conv_result = await db.execute(conv_stmt)
|
||||
conversation = conv_result.scalars().first()
|
||||
if not conversation:
|
||||
raise ERR_CONVERSATION_NOT_FOUND
|
||||
|
||||
is_assigned = conversation.assigned_agent_id == agent.user_id
|
||||
is_collaborator = agent.user_id in (conversation.collaborating_agent_ids or [])
|
||||
if not (is_assigned or is_collaborator):
|
||||
raise AppException(code=403, message="您不是该会话的坐席,无权操作")
|
||||
|
||||
# P2-3 修复:使用 is_(False) 而非 == False
|
||||
# 更新该会话的所有员工未读消息为已读
|
||||
stmt = (
|
||||
update(Message)
|
||||
.where(Message.conversation_id == conv_id_str)
|
||||
.where(Message.sender_type == "employee")
|
||||
.where(Message.is_read.is_(False))
|
||||
.values(is_read=True, status="read")
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.flush()
|
||||
|
||||
return success_response(message="标记已读成功")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# POST /api/messages/image — 上传图片
|
||||
# --------------------------------------------------------------------------
|
||||
@router.post("/messages/image")
|
||||
async def upload_image(
|
||||
file: UploadFile = File(...),
|
||||
agent: Agent = Depends(get_current_agent),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""上传图片文件。
|
||||
|
||||
文件大小限制:10MB
|
||||
|
||||
Args:
|
||||
file: 图片文件
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式,包含文件URL和元数据
|
||||
"""
|
||||
# 校验文件大小
|
||||
file.file.seek(0, 2)
|
||||
file_size = file.file.tell()
|
||||
file.file.seek(0)
|
||||
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
raise AppException(code=400, message=f"文件大小超过10MB限制")
|
||||
|
||||
# 校验文件类型
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
content_type = file.content_type
|
||||
if content_type not in allowed_types:
|
||||
raise AppException(code=400, message="不支持的图片格式")
|
||||
|
||||
# 生成保存路径
|
||||
import uuid as uuid_module
|
||||
file_ext = os.path.splitext(file.filename)[1] if file.filename else ".jpg"
|
||||
file_name = f"{uuid_module.uuid4()}{file_ext}"
|
||||
upload_dir = os.path.join("media", "images")
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
file_path = os.path.join(upload_dir, file_name)
|
||||
|
||||
# 保存文件
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 返回文件URL
|
||||
file_url = f"/media/images/{file_name}"
|
||||
return success_response(
|
||||
data={
|
||||
"url": file_url,
|
||||
"filename": file_name,
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# POST /api/messages/file — 上传文件
|
||||
# --------------------------------------------------------------------------
|
||||
@router.post("/messages/file")
|
||||
async def upload_message_file(
|
||||
file: UploadFile = File(...),
|
||||
agent: Agent = Depends(get_current_agent),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""上传普通文件。
|
||||
|
||||
文件大小限制:10MB
|
||||
|
||||
Args:
|
||||
file: 文件
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 统一响应格式,包含文件URL和元数据
|
||||
"""
|
||||
# 校��文��大小
|
||||
file.file.seek(0, 2)
|
||||
file_size = file.file.tell()
|
||||
file.file.seek(0)
|
||||
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
raise AppException(code=400, message=f"文件大小超过10MB限制")
|
||||
|
||||
# 生成保存路径
|
||||
import uuid as uuid_module
|
||||
original_name = file.filename or "file"
|
||||
file_ext = os.path.splitext(original_name)[1]
|
||||
file_name = f"{uuid_module.uuid4()}{file_ext}"
|
||||
upload_dir = os.path.join("media", "files")
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
file_path = os.path.join(upload_dir, file_name)
|
||||
|
||||
# 保存文件
|
||||
content = await file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# 返回文件URL
|
||||
file_url = f"/media/files/{file_name}"
|
||||
return success_response(
|
||||
data={
|
||||
"url": file_url,
|
||||
"filename": original_name,
|
||||
"file_size": file_size,
|
||||
"content_type": file.content_type,
|
||||
}
|
||||
)
|
||||
Reference in New Issue
Block a user