Files

228 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — AI Wingman API 路由
# =============================================================================
# 说明:坐席端 AI 智能副驾驶 API,包含 3 个核心端点:
# 1. POST /api/conversations/{id}/wingman/draft — 生成 AI 草稿回复
# 2. POST /api/conversations/{id}/wingman/summary — 生成会话自动摘要
# 3. POST /api/conversations/{id}/wingman/tags — 生成自动标签建议
#
# 所有端点需要坐席认证(get_current_agent
# =============================================================================
import logging
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.dependencies import dep_wingman_service
from app.models.agent import Agent
from app.models.conversation import Conversation
from app.models.message import Message
from app.services.wingman_service import WingmanService
from app.utils.response import ERR_NOT_FOUND, success_response
# 复用坐席认证依赖
from app.api.agents import get_current_agent
logger = logging.getLogger(__name__)
# 创建路由器
router = APIRouter()
# --------------------------------------------------------------------------
# 辅助函数
# --------------------------------------------------------------------------
async def _validate_conversation(
conversation_id: str,
agent: Agent,
db: AsyncSession,
) -> Conversation:
"""验证会话存在性并返回会话对象。
Args:
conversation_id: 会话ID
agent: 当前坐席
db: 数据库会话
Returns:
Conversation: 会话对象
Raises:
AppException: 会话不存在
"""
stmt = select(Conversation).where(Conversation.id == conversation_id)
result = await db.execute(stmt)
conversation = result.scalars().first()
if not conversation:
raise ERR_NOT_FOUND
return conversation
async def _get_recent_messages(
conversation_id: str,
db: AsyncSession,
limit: int = 20,
) -> list[dict]:
"""获取会话最近的消息历史(转换为字典列表)。
Args:
conversation_id: 会话ID
db: 数据库会话
limit: 获取的消息条数
Returns:
list[dict]: 消息字典列表
"""
stmt = (
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at.desc())
.limit(limit)
)
result = await db.execute(stmt)
messages = list(result.scalars().all())
# 按时间正序排列(最早的在前)
messages.reverse()
# 转换为字典列表
return [
{
"id": msg.id,
"sender_type": msg.sender_type,
"sender_name": msg.sender_name,
"content": msg.content,
"msg_type": msg.msg_type,
"created_at": msg.created_at.isoformat() if msg.created_at else "",
}
for msg in messages
]
# --------------------------------------------------------------------------
# POST /api/conversations/{conversation_id}/wingman/draft
# --------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/wingman/draft")
async def generate_draft(
conversation_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
wingman_service: WingmanService = Depends(dep_wingman_service),
):
"""生成 AI 草稿回复。
基于当前会话的消息历史,让 Wingman Agent 生成坐席可以采纳的草稿回复。
Args:
conversation_id: 会话ID
agent: 当前坐席(通过认证依赖注入)
db: 数据库会话
wingman_service: Wingman 服务实例
Returns:
Dict: 统一响应格式,包含草稿内容、置信度和推理说明
"""
# 1. 验证坐席身份 + 会话存在性
await _validate_conversation(conversation_id, agent, db)
# 2. 从数据库读取该会话的消息历史(最近 20 条)
messages = await _get_recent_messages(conversation_id, db, limit=20)
# 3. 调用 WingmanService 生成草稿
result = await wingman_service.generate_draft(
conversation_id=conversation_id,
messages=messages,
db=db,
)
return success_response(data=result)
# --------------------------------------------------------------------------
# POST /api/conversations/{conversation_id}/wingman/summary
# --------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/wingman/summary")
async def generate_summary(
conversation_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
wingman_service: WingmanService = Depends(dep_wingman_service),
):
"""生成会话自动摘要。
基于完整对话生成结构化摘要,包含问题、原因、解决方案。
通常在结单时调用。
Args:
conversation_id: 会话ID
agent: 当前坐席
db: 数据库会话
wingman_service: Wingman 服务实例
Returns:
Dict: 统一响应格式,包含问题、原因、解决方案
"""
# 1. 验证坐席身份 + 会话存在性
await _validate_conversation(conversation_id, agent, db)
# 2. 从数据库读取该会话的完整消息历史(最多 50 条)
messages = await _get_recent_messages(conversation_id, db, limit=50)
# 3. 调用 WingmanService 生成摘要
result = await wingman_service.generate_summary(
conversation_id=conversation_id,
messages=messages,
)
return success_response(data=result)
# --------------------------------------------------------------------------
# POST /api/conversations/{conversation_id}/wingman/tags
# --------------------------------------------------------------------------
@router.post("/conversations/{conversation_id}/wingman/tags")
async def suggest_tags(
conversation_id: str,
agent: Agent = Depends(get_current_agent),
db: AsyncSession = Depends(get_db),
wingman_service: WingmanService = Depends(dep_wingman_service),
):
"""生成自动标签建议。
基于对话内容建议标签分类,包含标签列表、分类和优先级。
Args:
conversation_id: 会话ID
agent: 当前坐席
db: 数据库会话
wingman_service: Wingman 服务实例
Returns:
Dict: 统一响应格式,包含建议标签、分类和优先级
"""
# 1. 验证坐席身份 + 会话存在性
conversation = await _validate_conversation(conversation_id, agent, db)
# 2. 从数据库读取该会话的消息历史(最近 20 条)
messages = await _get_recent_messages(conversation_id, db, limit=20)
# 3. 获取已有标签(用于避免重复建议)
existing_tags = {}
if hasattr(conversation, 'tags') and conversation.tags:
existing_tags = conversation.tags if isinstance(conversation.tags, dict) else {}
# 4. 调用 WingmanService 生成标签建议
result = await wingman_service.suggest_tags(
conversation_id=conversation_id,
messages=messages,
existing_tags=existing_tags,
)
return success_response(data=result)