267 lines
6.9 KiB
Python
267 lines
6.9 KiB
Python
|
|
# =============================================================================
|
||
|
|
# 企微IT智能服务台 — 统一认证依赖
|
||
|
|
# =============================================================================
|
||
|
|
# 说明:提供统一的认证依赖函数,支持:
|
||
|
|
# 1. get_current_user: 获取当前用户信息(包含角色)
|
||
|
|
# 2. require_role: 角色验证装饰器
|
||
|
|
# 3. require_admin: 管理员权限验证
|
||
|
|
# =============================================================================
|
||
|
|
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from functools import wraps
|
||
|
|
from typing import List, Optional
|
||
|
|
|
||
|
|
import redis.asyncio as aioredis
|
||
|
|
from fastapi import Depends, HTTPException, status
|
||
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||
|
|
|
||
|
|
from app.config import settings
|
||
|
|
from app.services.token_service import TokenService
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# HTTP Bearer 认证方案
|
||
|
|
security = HTTPBearer()
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class UserInfo:
|
||
|
|
"""用户信息数据类。
|
||
|
|
|
||
|
|
Attributes:
|
||
|
|
employee_id: 企微 UserID
|
||
|
|
name: 用户姓名
|
||
|
|
department: 部门
|
||
|
|
avatar: 头像URL
|
||
|
|
roles: 角色列表
|
||
|
|
current_role: 当前选择的角色
|
||
|
|
login_source: 登录来源
|
||
|
|
"""
|
||
|
|
|
||
|
|
employee_id: str
|
||
|
|
name: str
|
||
|
|
department: str
|
||
|
|
avatar: str
|
||
|
|
roles: List[str]
|
||
|
|
current_role: str
|
||
|
|
login_source: str
|
||
|
|
|
||
|
|
|
||
|
|
# Redis 连接池(单例)
|
||
|
|
_redis_pool: Optional[aioredis.Redis] = None
|
||
|
|
|
||
|
|
|
||
|
|
async def get_redis() -> aioredis.Redis:
|
||
|
|
"""获取 Redis 连接。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
aioredis.Redis: Redis 异步客户端
|
||
|
|
"""
|
||
|
|
global _redis_pool
|
||
|
|
if _redis_pool is None:
|
||
|
|
_redis_pool = settings.create_redis_client()
|
||
|
|
return _redis_pool
|
||
|
|
|
||
|
|
|
||
|
|
# 共享服务实例(用于 wecom_callback.py 等模块)
|
||
|
|
# 这些函数提供同步获取服务实例的方式,用于非 FastAPI DI 的场景
|
||
|
|
def get_shared_redis() -> aioredis.Redis:
|
||
|
|
"""获取 Redis 客户端(同步版本,用于非 async 场景)。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
aioredis.Redis: Redis 客户端实例
|
||
|
|
"""
|
||
|
|
return settings.create_redis_client()
|
||
|
|
|
||
|
|
|
||
|
|
def get_shared_wecom_service():
|
||
|
|
"""获取 WecomService 共享实例。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
WecomService: 企微服务实例
|
||
|
|
"""
|
||
|
|
from app.services.wecom_service import WecomService
|
||
|
|
return WecomService(settings.create_redis_client())
|
||
|
|
|
||
|
|
|
||
|
|
def get_shared_ai_handler():
|
||
|
|
"""获取 AIHandler 共享实例。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
AIHandler: AI 处理器实例
|
||
|
|
"""
|
||
|
|
from app.services.ai_handler import AIHandler
|
||
|
|
from app.services.ai_service import AIService
|
||
|
|
return AIHandler(ai_service=AIService())
|
||
|
|
|
||
|
|
|
||
|
|
# FastAPI Depends 函数(用于路由依赖注入)
|
||
|
|
async def dep_redis() -> aioredis.Redis:
|
||
|
|
"""Redis 客户端依赖注入。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
aioredis.Redis: Redis 异步客户端
|
||
|
|
"""
|
||
|
|
return await get_redis()
|
||
|
|
|
||
|
|
|
||
|
|
def dep_wecom_service():
|
||
|
|
"""WecomService 依赖注入。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
WecomService: 企微服务实例
|
||
|
|
"""
|
||
|
|
from app.services.wecom_service import WecomService
|
||
|
|
return WecomService(settings.create_redis_client())
|
||
|
|
|
||
|
|
|
||
|
|
def dep_ai_handler():
|
||
|
|
"""AIHandler 依赖注入。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
AIHandler: AI 处理器实例
|
||
|
|
"""
|
||
|
|
from app.services.ai_handler import AIHandler
|
||
|
|
from app.services.ai_service import AIService
|
||
|
|
return AIHandler(ai_service=AIService())
|
||
|
|
|
||
|
|
|
||
|
|
def dep_wingman_service():
|
||
|
|
"""WingmanService 依赖注入。
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
WingmanService: AI Wingman 服务实例
|
||
|
|
"""
|
||
|
|
from app.services.wingman_service import WingmanService
|
||
|
|
return WingmanService()
|
||
|
|
|
||
|
|
|
||
|
|
# 应用生命周期管理函数
|
||
|
|
async def init_shared_services():
|
||
|
|
"""初始化共享服务(应用启动时调用)。
|
||
|
|
|
||
|
|
创建 Redis 连接池,初始化共享服务实例。
|
||
|
|
"""
|
||
|
|
global _redis_pool
|
||
|
|
_redis_pool = settings.create_redis_client()
|
||
|
|
logger.info("共享服务初始化完成")
|
||
|
|
|
||
|
|
|
||
|
|
async def cleanup_shared_services():
|
||
|
|
"""清理共享服务(应用关闭时调用)。
|
||
|
|
|
||
|
|
关闭 Redis 连接池。
|
||
|
|
"""
|
||
|
|
global _redis_pool
|
||
|
|
if _redis_pool:
|
||
|
|
await _redis_pool.close()
|
||
|
|
_redis_pool = None
|
||
|
|
logger.info("共享服务清理完成")
|
||
|
|
|
||
|
|
|
||
|
|
async def get_current_user(
|
||
|
|
credentials: HTTPAuthorizationCredentials = Depends(security),
|
||
|
|
) -> UserInfo:
|
||
|
|
"""统一认证依赖:从 Token 获取用户信息。
|
||
|
|
|
||
|
|
支持新旧两种 Token 格式。
|
||
|
|
|
||
|
|
Args:
|
||
|
|
credentials: HTTP Bearer Token
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
UserInfo: 用户信息
|
||
|
|
|
||
|
|
Raises:
|
||
|
|
HTTPException: Token 无效或已过期
|
||
|
|
"""
|
||
|
|
token = credentials.credentials
|
||
|
|
|
||
|
|
# 获取 Redis 连接
|
||
|
|
redis_client = await get_redis()
|
||
|
|
|
||
|
|
# 创建 Token 服务
|
||
|
|
token_service = TokenService(redis_client)
|
||
|
|
|
||
|
|
# 获取用户信息
|
||
|
|
user_info = await token_service.get_user_info(token)
|
||
|
|
|
||
|
|
if not user_info:
|
||
|
|
raise HTTPException(
|
||
|
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
|
|
detail="Token 无效或已过期",
|
||
|
|
headers={"WWW-Authenticate": "Bearer"},
|
||
|
|
)
|
||
|
|
|
||
|
|
return UserInfo(
|
||
|
|
employee_id=user_info["employee_id"],
|
||
|
|
name=user_info.get("name", ""),
|
||
|
|
department=user_info.get("department", ""),
|
||
|
|
avatar=user_info.get("avatar", ""),
|
||
|
|
roles=user_info.get("roles", ["user"]),
|
||
|
|
current_role=user_info.get("current_role", "user"),
|
||
|
|
login_source=user_info.get("login_source", "portal"),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def require_role(*required_roles: str):
|
||
|
|
"""角色验证装饰器。
|
||
|
|
|
||
|
|
检查用户是否拥有指定角色之一。
|
||
|
|
|
||
|
|
Args:
|
||
|
|
*required_roles: 允许的角色列表
|
||
|
|
|
||
|
|
Returns:
|
||
|
|
装饰器函数
|
||
|
|
|
||
|
|
Example:
|
||
|
|
@router.get("/api/admin/dashboard")
|
||
|
|
@require_role("admin")
|
||
|
|
async def get_dashboard(current_user: UserInfo = Depends(get_current_user)):
|
||
|
|
pass
|
||
|
|
"""
|
||
|
|
|
||
|
|
def decorator(func):
|
||
|
|
@wraps(func)
|
||
|
|
async def wrapper(
|
||
|
|
*args,
|
||
|
|
current_user: UserInfo = Depends(get_current_user),
|
||
|
|
**kwargs,
|
||
|
|
):
|
||
|
|
# 检查用户是否有任一所需角色
|
||
|
|
user_roles = set(current_user.roles)
|
||
|
|
required = set(required_roles)
|
||
|
|
|
||
|
|
if not user_roles.intersection(required):
|
||
|
|
logger.warning(
|
||
|
|
f"用户 {current_user.employee_id} 角色不足: "
|
||
|
|
f"拥有 {current_user.roles}, 需要 {required_roles}"
|
||
|
|
)
|
||
|
|
raise HTTPException(
|
||
|
|
status_code=status.HTTP_403_FORBIDDEN,
|
||
|
|
detail=f"需要以下角色之一: {', '.join(required_roles)}",
|
||
|
|
)
|
||
|
|
|
||
|
|
return await func(*args, current_user=current_user, **kwargs)
|
||
|
|
|
||
|
|
return wrapper
|
||
|
|
|
||
|
|
return decorator
|
||
|
|
|
||
|
|
|
||
|
|
def require_admin(func):
|
||
|
|
"""管理员权限验证装饰器。
|
||
|
|
|
||
|
|
等同于 @require_role("admin")。
|
||
|
|
|
||
|
|
Example:
|
||
|
|
@router.get("/api/admin/dashboard")
|
||
|
|
@require_admin
|
||
|
|
async def get_dashboard(current_user: UserInfo = Depends(get_current_user)):
|
||
|
|
pass
|
||
|
|
"""
|
||
|
|
return require_role("admin")(func)
|