Files
wecom_it_smart_desk/backend/app/main.py
T

534 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# 企微IT智能服务台 — FastAPI 应用入口
# =============================================================================
# 说明:FastAPI 应用的主入口文件,负责:
# 1. 创建 FastAPI 应用实例
# 2. 配置 CORS 跨域支持
# 3. 挂载 API 路由
# 4. 注册全局异常处理器
# 5. 添加启动事件(初始化默认数据)
# 6. 提供健康检查端点
# =============================================================================
import json
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
# 导入配置(读取环境变量)
from app.config import settings
# 导入路由汇总
from app.api.router import api_router
# 导入共享服务生命周期管理
from app.dependencies import init_shared_services, cleanup_shared_services
# 导入异常处理器和异常类
from app.utils.response import AppException, app_exception_handler
# 配置日志格式
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
# --------------------------------------------------------------------------
# 应用生命周期管理(启动和关闭事件)
# --------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理。
在应用启动时执行初始化操作(如插入默认数据),
在应用关闭时执行清理操作。
"""
# ===== 启动事件 =====
logger.info("🚀 企微IT智能服务台启动中...")
# 校验关键配置项(防止生产环境忘记配置导致静默失败)
_validate_config()
# 初始化共享服务实例(Redis/AIService/WecomService/AIHandler
# 这些实例在应用运行期间复用,避免每次请求重新创建导致资源泄漏
await init_shared_services()
# 自动建表(开发阶段,生产环境应用 Alembic 迁移)
await _auto_create_tables()
# 初始化默认数据
await _init_default_data()
logger.info("✅ 企微IT智能服务台启动完成")
yield # 应用运行中
# ===== 关闭事件 =====
logger.info("👋 企微IT智能服务台关闭中...")
# 清理共享服务实例(关闭 Redis 连接、httpx 连接池等)
await cleanup_shared_services()
logger.info("✅ 企微IT智能服务台已关闭")
# --------------------------------------------------------------------------
# 配置校验(启动时检查关键配置项是否为占位符)
# --------------------------------------------------------------------------
# 占位符列表:这些默认值在 config.py 中设置,生产环境必须替换
_PLACEHOLDER_VALUES = {
"wecom_corp_id": "ww1234567890abcdef",
"wecom_secret": "your-agent-secret",
"wecom_token": "your-callback-token",
"wecom_encoding_aes_key": "your-aes-key-43-characters-long-encoding-key",
}
def _validate_config():
"""校验关键配置项是否为占位符。
生产环境部署时,如果忘记修改 config.py 中的占位符值,
会导致 AES 解密静默失败、企微 API 调用 400 等问题。
此函数在启动时检查这些关键配置,输出醒目警告。
"""
warnings = []
for key, placeholder in _PLACEHOLDER_VALUES.items():
actual_value = getattr(settings, key, "")
if actual_value == placeholder:
warnings.append(f" ⚠️ {key} = '{placeholder}' (未配置!)")
if warnings:
logger.warning(
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
"⚠️ 检测到以下关键配置仍为占位符,请修改 .env 或环境变量:\n"
+ "\n".join(warnings)
+ "\n"
" 企微回调消息将无法正常解密!\n"
" 参考 .env.example 或项目部署手册进行配置。\n"
"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
)
else:
logger.info("✅ 关键配置校验通过")
# --------------------------------------------------------------------------
# 自动建表(开发阶段使用)
# --------------------------------------------------------------------------
async def _auto_create_tables():
"""自动创建所有数据库表。
开发阶段使用,根据模型定义自动创建表。
生产环境应使用 Alembic 迁移来管理表结构变更。
工作原理:
1. 获取 engine(懒加载)
2. 通过 Base.metadata 收集所有模型定义
3. 执行 CREATE TABLE IF NOT EXISTS
"""
from app.database import _get_engine, Base
# 导入所有模型,确保 Base.metadata 知道所有表的定义
# 如果不导入,Base.metadata 里只有基类,不会建任何表
import app.models # noqa: F401
engine = _get_engine()
async with engine.begin() as conn:
# checkfirst=True: 只创建不存在的表,不会覆盖已有表和数据
await conn.run_sync(Base.metadata.create_all, checkfirst=True)
logger.info("数据库表检查/创建完成")
# --------------------------------------------------------------------------
# 初始化默认数据
# --------------------------------------------------------------------------
async def _init_default_data():
"""初始化默认数据。
当数据库表为空时,插入预置配置数据,包括:
1. system_configs — 系统配置(关键词、阈值、话术等)
2. funny_phrases — 趣味话术
3. quick_reply_templates — 快速回复模板
4. approval_links — 审批流程链接
5. software_downloads — 软件下载入口
只在表为空时插入,避免重复插入。
"""
from app.database import _get_session_factory
from app.models.system_config import SystemConfig
from app.models.funny_phrase import FunnyPhrase
from app.models.quick_reply_template import QuickReplyTemplate
from app.models.approval_link import ApprovalLink
from app.models.software_download import SoftwareDownload
async_session_factory = _get_session_factory()
async with async_session_factory() as db:
try:
# 1. 初始化系统配置
await _init_system_configs(db, SystemConfig)
# 2. 初始化趣味话术
await _init_funny_phrases(db, FunnyPhrase)
# 3. 初始化快速回复模板
await _init_quick_reply_templates(db, QuickReplyTemplate)
# 4. 初始化审批流程链接
await _init_approval_links(db, ApprovalLink)
# 5. 初始化软件下载入口
await _init_software_downloads(db, SoftwareDownload)
await db.commit()
logger.info("默认数据初始化完成")
except Exception as e:
await db.rollback()
logger.error(f"默认数据初始化失败: {e}")
async def _init_system_configs(db, SystemConfig):
"""初始化系统配置项。"""
from sqlalchemy import select, func
count_stmt = select(func.count(SystemConfig.id))
result = await db.execute(count_stmt)
count = result.scalar() or 0
if count > 0:
logger.debug(f"system_configs 已有 {count} 条数据,跳过初始化")
return
configs = [
SystemConfig(config_key="hand_raise_keywords", config_value=json.dumps(["转人工", "人工", "人工服务", "真人", "客服", "帮我转人工", "找人工"], ensure_ascii=False), description="举手触发关键词"),
SystemConfig(config_key="emotion_keywords_angry", config_value=json.dumps(["崩溃", "愤怒", "投诉", "差劲", "垃圾", "太差了", "受不了"], ensure_ascii=False), description="愤怒情绪关键词"),
SystemConfig(config_key="emotion_keywords_urgent", config_value=json.dumps(["", "紧急", "马上", "立刻", "赶紧", "十万火急", "快点"], ensure_ascii=False), description="紧急情绪关键词"),
SystemConfig(config_key="emotion_keywords_worried", config_value=json.dumps(["担心", "害怕", "出错", "丢失", "完蛋", "糟糕"], ensure_ascii=False), description="担忧情绪关键词"),
SystemConfig(config_key="intervene_round_threshold", config_value="3", description="需介入追问轮次阈值"),
SystemConfig(config_key="urgency_base_keyword_score", config_value="1", description="关键词匹配基础加分"),
SystemConfig(config_key="urgency_emotion_bonus", config_value="1", description="情绪标记加成分"),
SystemConfig(config_key="urgency_vip_bonus", config_value="1", description="VIP加成分"),
SystemConfig(config_key="urgency_repeat_bonus", config_value="1", description="重复追问加成分"),
SystemConfig(config_key="polling_interval_seconds", config_value="3", description="坐席轮询间隔(秒)"),
SystemConfig(config_key="access_token_buffer_seconds", config_value="300", description="access_token提前刷新时间(秒)"),
SystemConfig(config_key="emergency_mode", config_value="false", description="应急模式开关(true=启用员工服务通道,智能服务台降级)"),
]
db.add_all(configs)
await db.flush()
logger.info(f"初始化 system_configs: {len(configs)}")
async def _init_funny_phrases(db, FunnyPhrase):
"""初始化趣味话术。"""
from sqlalchemy import select, func
count_stmt = select(func.count(FunnyPhrase.id))
result = await db.execute(count_stmt)
count = result.scalar() or 0
if count > 0:
logger.debug(f"funny_phrases 已有 {count} 条数据,跳过初始化")
return
phrases = [
FunnyPhrase(scene="shake", content="大哥,俺这就去摇人,稍等...", tone="亲切", sort_order=1),
FunnyPhrase(scene="keyword", content="收到!这就帮您摇位大神来", tone="稍正式", sort_order=1),
FunnyPhrase(scene="waiting", content="人还在路上,别急别急~", tone="安抚", sort_order=1),
FunnyPhrase(scene="connected", content="人摇来了!IT坐席为您服务", tone="明确交接", sort_order=1),
FunnyPhrase(scene="timeout", content="坐席都在忙,不过AI还在呢,要不先聊聊?我再继续摇", tone="降级安抚", sort_order=1),
FunnyPhrase(scene="vip", content="这就帮您安排专家,请稍候", tone="正式", sort_order=1),
]
db.add_all(phrases)
await db.flush()
logger.info(f"初始化 funny_phrases: {len(phrases)}")
async def _init_quick_reply_templates(db, QuickReplyTemplate):
"""初始化快速回复模板。"""
from sqlalchemy import select, func
count_stmt = select(func.count(QuickReplyTemplate.id))
result = await db.execute(count_stmt)
count = result.scalar() or 0
if count > 0:
logger.debug(f"quick_reply_templates 已有 {count} 条数据,跳过初始化")
return
templates = [
QuickReplyTemplate(category="账号", title="密码重置", content="您好{employee_name},您的密码重置链接已发送至您的企业邮箱,请在30分钟内完成操作。", variables=["employee_name"], sort_order=1),
QuickReplyTemplate(category="账号", title="账号解锁", content="您好,您的账号已解锁,请5分钟后重新尝试登录。如仍有问题请联系IT服务台。", variables=[], sort_order=2),
QuickReplyTemplate(category="网络", title="VPN连接指引", content="请按以下步骤操作:1.打开VPN客户端 2.选择\u201c公司内网\u201d 3.输入域账号密码 4.点击连接。详细图文教程请查看右侧\u201c操作步骤\u201d", variables=[], sort_order=3),
QuickReplyTemplate(category="网络", title="WiFi连接", content="公司WiFi名称:Office-5G,密码请咨询前台或查看工位标签。", variables=[], sort_order=4),
QuickReplyTemplate(category="软件", title="软件安装申请", content="您好,软件安装需要提交审批申请。请在右侧\u201c审批流程\u201d中点击\u201c软件安装申请\u201d链接提交。", variables=[], sort_order=5),
QuickReplyTemplate(category="硬件", title="设备报修", content="您好,设备报修请提交工单。请在右侧\u201c审批流程\u201d中点击\u201c设备报修\u201d链接提交,IT会在24小时内联系您。", variables=[], sort_order=6),
QuickReplyTemplate(category="通用", title="会话结束", content="您好,请问还有其他问题吗?如无其他问题,我将结束本次服务。祝您工作顺利!", variables=[], sort_order=7),
QuickReplyTemplate(category="通用", title="稍等回复", content="收到,我正在为您查询,请稍等片刻。", variables=[], sort_order=8),
]
db.add_all(templates)
await db.flush()
logger.info(f"初始化 quick_reply_templates: {len(templates)}")
async def _init_approval_links(db, ApprovalLink):
"""初始化审批流程链接。"""
from sqlalchemy import select, func
count_stmt = select(func.count(ApprovalLink.id))
result = await db.execute(count_stmt)
count = result.scalar() or 0
if count > 0:
logger.debug(f"approval_links 已有 {count} 条数据,跳过初始化")
return
links = [
ApprovalLink(category="IT", title="软件安装申请", url="https://审批系统地址/software-install", sort_order=1),
ApprovalLink(category="IT", title="设备报修工单", url="https://审批系统地址/device-repair", sort_order=2),
ApprovalLink(category="IT", title="VPN开通申请", url="https://审批系统地址/vpn-apply", sort_order=3),
ApprovalLink(category="IT", title="权限申请", url="https://审批系统地址/permission-apply", sort_order=4),
ApprovalLink(category="HR", title="入职手续", url="https://审批系统地址/onboarding", sort_order=5),
ApprovalLink(category="HR", title="离职手续", url="https://审批系统地址/offboarding", sort_order=6),
ApprovalLink(category="行政", title="办公用品申领", url="https://审批系统地址/office-supplies", sort_order=7),
ApprovalLink(category="财务", title="报销申请", url="https://审批系统地址/reimbursement", sort_order=8),
]
db.add_all(links)
await db.flush()
logger.info(f"初始化 approval_links: {len(links)}")
async def _init_software_downloads(db, SoftwareDownload):
"""初始化软件下载入口。"""
from sqlalchemy import select, func
count_stmt = select(func.count(SoftwareDownload.id))
result = await db.execute(count_stmt)
count = result.scalar() or 0
if count > 0:
logger.debug(f"software_downloads 已有 {count} 条数据,跳过初始化")
return
downloads = [
SoftwareDownload(category="办公", name="企业微信", version="最新版", platform="全平台", download_url="https://work.weixin.qq.com/#download", sort_order=1),
SoftwareDownload(category="办公", name="WPS Office", version="12.1", platform="Windows/Mac", download_url="https://www.wps.cn/download", sort_order=2),
SoftwareDownload(category="办公", name="Microsoft Teams", version="最新版", platform="全平台", download_url="https://www.microsoft.com/teams/download", sort_order=3),
SoftwareDownload(category="开发", name="VS Code", version="1.90", platform="Windows/Mac/Linux", download_url="https://code.visualstudio.com/download", sort_order=4),
SoftwareDownload(category="开发", name="Git", version="2.45", platform="Windows/Mac", download_url="https://git-scm.com/download", sort_order=5),
SoftwareDownload(category="安全", name="公司VPN客户端", version="3.2", platform="Windows/Mac", download_url="https://内部下载地址/vpn-client", sort_order=6),
SoftwareDownload(category="工具", name="7-Zip", version="24.06", platform="Windows", download_url="https://www.7-zip.org/download", sort_order=7),
SoftwareDownload(category="工具", name="PDF阅读器", version="最新版", platform="Windows/Mac", download_url="https://get.adobe.com/reader/", sort_order=8),
]
db.add_all(downloads)
await db.flush()
logger.info(f"初始化 software_downloads: {len(downloads)}")
# --------------------------------------------------------------------------
# 创建 FastAPI 应用
# --------------------------------------------------------------------------
def create_app() -> FastAPI:
"""创建并配置 FastAPI 应用实例。
使用工厂函数模式,方便测试时创建不同的应用实例。
Returns:
FastAPI: 配置好的应用实例
"""
# 创建 FastAPI 实例
# lifespan: 应用生命周期管理(启动/关闭事件)
app = FastAPI(
title="企微IT智能服务台",
description="基于企微自建应用消息API的IT服务坐席系统",
version="1.0.0",
lifespan=lifespan,
)
# ----------------------------------------------------------------------
# 配置 CORS(跨域资源共享)
# ----------------------------------------------------------------------
# 为什么需要 CORS:前端和后端运行在不同端口,浏览器会阻止跨域请求
# allow_origins: 允许的前端地址列表
# allow_credentials: 允许携带 Cookie
# allow_methods: 允许的 HTTP 方法(仅允许必要的方法)
# allow_headers: 允许的请求头(仅允许必要的头)
# ----------------------------------------------------------------------
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Authorization", "Content-Type", "X-Employee-Id"],
)
# ----------------------------------------------------------------------
# 速率限制(防止暴力破解和 DDoS)
# ----------------------------------------------------------------------
# slowapi 为每个 IP 维护请求计数器(默认内存后端)
# 登录接口严格限制(防暴力破解),普通接口宽松限制(防滥用)
# ----------------------------------------------------------------------
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.responses import JSONResponse as RateLimitJSONResponse
from app.utils.response import error_response as _rl_error_response
# 速率限制器:按客户端 IP 维度限制
# 移除 env_file=None 参数:slowapi 0.1.9 不支持该参数
# python-dotenv 已在应用启动时处理 .env 文件
limiter = Limiter(key_func=get_remote_address)
# 注册速率限制超限处理器
app.state.limiter = limiter
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request, exc: RateLimitExceeded):
"""速率限制超限响应:返回 429 状态码和友好提示。"""
return RateLimitJSONResponse(
status_code=429,
content=_rl_error_response(429, f"请求过于频繁,请 {exc.detail} 后重试"),
)
# ----------------------------------------------------------------------
# 注册全局异常处理器
# ----------------------------------------------------------------------
# 当业务逻辑抛出 AppException 时,自动转换为统一响应格式
# ----------------------------------------------------------------------
app.add_exception_handler(AppException, app_exception_handler)
# ----------------------------------------------------------------------
# 注册兜底异常处理器(捕获所有未预期的异常,避免裸 500)
# ----------------------------------------------------------------------
# 数据库连接失败、Redis 异常、第三方库错误等非 AppException 异常
# 都会被捕获并返回统一格式的错误响应,同时记录详细日志
# ----------------------------------------------------------------------
import traceback
from fastapi.responses import JSONResponse
from app.utils.response import error_response
@app.exception_handler(Exception)
async def catch_all_exception_handler(request, exc):
"""兜底异常处理器:捕获所有未预期异常。
安全处理:
- 详细异常信息记录到日志(供排查)
- 响应只返回通用错误信息(避免泄露内部细节)
"""
# 记录完整错误堆栈(用于排查问题)
logger.error(f"未预期异常: {exc}\n{traceback.format_exc()}")
# 返回统一格式的错误响应(HTTP 200 + 业务错误码)
# 安全:响应不包含具体异常信息,仅返回通用消息
return JSONResponse(
status_code=200,
content=error_response(1005, "服务器内部错误,请稍后重试或联系管理员")
)
# ----------------------------------------------------------------------
# 请求日志 + 兜底异常中间件
# ----------------------------------------------------------------------
# 使用中间件而非 exception_handler 来捕获所有异常
# 原因:FastAPI 的 @app.exception_handler(Exception) 在某些情况下
# 无法捕获异常(如依赖注入 yield 阶段的异常),而中间件更可靠
# ----------------------------------------------------------------------
import traceback as tb_module
from starlette.requests import Request
from starlette.responses import Response as StarletteResponse, JSONResponse as StarJSONResponse
from app.utils.response import error_response as _error_response
@app.middleware("http")
async def catch_errors_and_log(request: Request, call_next):
"""请求日志 + 兜底异常中间件。
1. 记录每个请求的方法、路径、状态码
2. 捕获所有未处理异常,返回统一格式的 JSON 错误响应
"""
# 使用 print 而非 logger,确保输出立即可见(调试阶段)
print(f">>> [MW] 收到请求: {request.method} {request.url.path}", flush=True)
try:
response: StarletteResponse = await call_next(request)
print(f"<<< [MW] 响应完成: {request.method} {request.url.path}{response.status_code}", flush=True)
return response
except Exception as e:
# 捕获所有未处理异常(包括依赖注入阶段的异常)
# 安全:详细日志仅记录,响应不泄露异常信息
error_tb = tb_module.format_exc()
print(f"!!! [MW] 未捕获异常: {request.method} {request.url.path}\n{error_tb}", flush=True)
logger.error(f"!!! 未捕获异常: {request.method} {request.url.path}\n{error_tb}")
# 返回统一格式的 JSON 错误响应(HTTP 200 + 业务错误码 1005
# 安全:响应不包含具体异常信息
return StarJSONResponse(
status_code=200,
content=_error_response(1005, "服务器内部错误,请稍后重试或联系管理员"),
)
# ----------------------------------------------------------------------
# 挂载 API 路由
# ----------------------------------------------------------------------
# 注意:nginx 已经通过 location /api/ 处理了前缀路由,
# 请求到达后端时 /api/ 已被 strip,因此此处不需要再加 /api 前缀
app.include_router(api_router)
# ----------------------------------------------------------------------
# 挂载 WebSocket 路由
# ----------------------------------------------------------------------
# WebSocket 端点不挂 /api 前缀,直接注册在根路径
# 原因:WebSocket 不是 REST API,前端通过 /ws/{agent_id} 连接
# Vite 开发服务器单独配置了 /ws 的 WebSocket 代理
# ----------------------------------------------------------------------
from app.api.ws import router as ws_router
app.include_router(ws_router)
# ----------------------------------------------------------------------
# 诊断端点(调试用,生产环境删除)
# ----------------------------------------------------------------------
@app.get("/test-ping", tags=["诊断"])
async def test_ping():
"""简单测试 — 不依赖数据库和 Redis"""
return {"code": 0, "message": "success", "data": {"message": "pong"}}
@app.get("/test-error", tags=["诊断"])
async def test_error():
"""测试异常处理 — 故意抛出异常"""
raise Exception("这是故意抛出的测试异常")
# ----------------------------------------------------------------------
# 健康检查端点
# ----------------------------------------------------------------------
# 用于 Docker 健康检查和负载均衡探针
# 返回简单的 JSON 表示服务正在运行
@app.get("/health", tags=["系统"])
async def health_check():
"""健康检查端点。
返回服务运行状态,用于:
- Docker 健康检查
- 负载均衡探针
- 监控系统检测服务是否存活
"""
return {"status": "ok", "service": "wecom-it-smart-desk"}
# ----------------------------------------------------------------------
# 打印所有已注册的路由(调试用)
# ----------------------------------------------------------------------
routes_info = []
for route in app.routes:
if hasattr(route, 'methods') and hasattr(route, 'path'):
routes_info.append(f" {', '.join(route.methods)} {route.path}")
if routes_info:
logger.info(f"已注册路由 ({len(routes_info)} 个):\n" + "\n".join(routes_info))
else:
logger.warning("⚠️ 没有注册任何路由!")
return app
# 创建应用实例(uvicorn 通过 app.main:app 引用此对象)
app = create_app()