351 lines
12 KiB
Python
351 lines
12 KiB
Python
# =============================================================================
|
||
# 企微IT智能服务台 — 角色映射服务
|
||
# =============================================================================
|
||
# 说明:处理角色自动映射逻辑,支持以下来源:
|
||
# 1. 企微标签映射(wecom_tag)
|
||
# 2. eHR 字段映射(ehr_position)
|
||
# 3. 管理后台手动分配(manual)
|
||
# =============================================================================
|
||
|
||
import logging
|
||
import re
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional, Set
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.role import Role
|
||
from app.models.role_mapping_rule import RoleMappingRule
|
||
from app.models.user_role import UserRole
|
||
from app.services.wecom_service import WecomService
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _mask_sensitive_data(value: str, visible_chars: int = 3) -> str:
|
||
"""脱敏处理敏感数据。
|
||
|
||
Args:
|
||
value: 原始值
|
||
visible_chars: 开头保留的字符数
|
||
|
||
Returns:
|
||
str: 脱敏后的值,如 "abc***def"
|
||
"""
|
||
if not value:
|
||
return ""
|
||
if len(value) <= visible_chars:
|
||
return "*" * len(value)
|
||
return f"{value[:visible_chars]}{'*' * (len(value) - visible_chars)}"
|
||
|
||
|
||
class RoleMappingService:
|
||
"""角色映射服务。
|
||
|
||
根据用户的企微标签、eHR岗位等信息,自动映射角色。
|
||
"""
|
||
|
||
def __init__(self, db: AsyncSession, wecom_service: Optional[WecomService] = None):
|
||
"""初始化角色映射服务。
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
wecom_service: 企微API服务(可选,用于获取用户标签)
|
||
"""
|
||
self.db = db
|
||
self.wecom_service = wecom_service
|
||
|
||
async def get_user_roles(self, employee_id: str) -> List[str]:
|
||
"""获取用户的角色列表。
|
||
|
||
查询 user_roles 表,返回用户拥有的角色标识列表。
|
||
|
||
Args:
|
||
employee_id: 企微 UserID
|
||
|
||
Returns:
|
||
List[str]: 角色标识列表(如 ["user", "agent"])
|
||
"""
|
||
stmt = (
|
||
select(Role.name)
|
||
.join(UserRole, Role.id == UserRole.role_id)
|
||
.where(UserRole.employee_id == employee_id)
|
||
.where(
|
||
# 过滤已过期的角色
|
||
(UserRole.expires_at.is_(None)) | (UserRole.expires_at > datetime.now())
|
||
)
|
||
)
|
||
result = await self.db.execute(stmt)
|
||
roles = [row[0] for row in result.all()]
|
||
|
||
# 如果没有角色,添加默认的 user 角色
|
||
if not roles:
|
||
roles = ["user"]
|
||
|
||
return roles
|
||
|
||
async def sync_user_roles(
|
||
self,
|
||
employee_id: str,
|
||
wecom_tags: Optional[List[str]] = None,
|
||
ehr_position: Optional[str] = None,
|
||
) -> List[str]:
|
||
"""同步用户角色。
|
||
|
||
根据企微标签和eHR岗位,自动分配或撤销角色。
|
||
|
||
Args:
|
||
employee_id: 企微 UserID
|
||
wecom_tags: 企微标签列表(可选)
|
||
ehr_position: eHR岗位(可选)
|
||
|
||
Returns:
|
||
List[str]: 同步后的角色列表
|
||
"""
|
||
# 1. 获取当前角色
|
||
current_roles = await self.get_user_roles(employee_id)
|
||
|
||
# 2. 获取映射规则
|
||
mapping_rules = await self._get_active_mapping_rules()
|
||
|
||
# 3. 根据规则确定应该拥有的角色
|
||
should_have_roles: Set[str] = {"user"} # 所有人都有 user 角色
|
||
|
||
for rule in mapping_rules:
|
||
if rule.source_type == "wecom_tag" and wecom_tags:
|
||
# 检查标签是否匹配
|
||
if rule.source_value in wecom_tags:
|
||
role_name = await self._get_role_name_by_id(rule.role_id)
|
||
if role_name:
|
||
should_have_roles.add(role_name)
|
||
|
||
elif rule.source_type == "ehr_position" and ehr_position:
|
||
# 检查岗位关键词是否匹配
|
||
if rule.source_value in ehr_position:
|
||
role_name = await self._get_role_name_by_id(rule.role_id)
|
||
if role_name:
|
||
should_have_roles.add(role_name)
|
||
|
||
# 4. 计算需要添加和删除的角色
|
||
current_set = set(current_roles)
|
||
to_add = should_have_roles - current_set
|
||
to_remove = current_set - should_have_roles - {"user"} # 不删除 user 角色
|
||
|
||
# 5. 添加新角色
|
||
for role_name in to_add:
|
||
await self._add_role(employee_id, role_name, source="tag")
|
||
|
||
# 6. 撤销不再需要的角色(仅撤销自动分配的)
|
||
for role_name in to_remove:
|
||
await self._remove_auto_role(employee_id, role_name)
|
||
|
||
# 7. 返回同步后的角色列表
|
||
return await self.get_user_roles(employee_id)
|
||
|
||
async def _get_active_mapping_rules(self) -> List[RoleMappingRule]:
|
||
"""获取所有启用的映射规则。
|
||
|
||
Returns:
|
||
List[RoleMappingRule]: 映射规则列表
|
||
"""
|
||
stmt = (
|
||
select(RoleMappingRule)
|
||
.where(RoleMappingRule.is_active == True)
|
||
.order_by(RoleMappingRule.priority.desc())
|
||
)
|
||
result = await self.db.execute(stmt)
|
||
return list(result.scalars().all())
|
||
|
||
async def _get_role_name_by_id(self, role_id: str) -> Optional[str]:
|
||
"""根据角色ID获取角色名称。
|
||
|
||
Args:
|
||
role_id: 角色ID
|
||
|
||
Returns:
|
||
Optional[str]: 角色名称,如果不存在返回 None
|
||
"""
|
||
stmt = select(Role.name).where(Role.id == role_id)
|
||
result = await self.db.execute(stmt)
|
||
row = result.first()
|
||
return row[0] if row else None
|
||
|
||
async def _add_role(self, employee_id: str, role_name: str, source: str) -> None:
|
||
"""为用户添加角色。
|
||
|
||
Args:
|
||
employee_id: 企微 UserID
|
||
role_name: 角色标识
|
||
source: 角色来源(auto/tag/ehr/manual)
|
||
"""
|
||
# 查询角色
|
||
stmt = select(Role).where(Role.name == role_name)
|
||
result = await self.db.execute(stmt)
|
||
role = result.scalars().first()
|
||
|
||
if not role:
|
||
logger.warning(f"角色 {role_name} 不存在,跳过添加")
|
||
return
|
||
|
||
# 检查是否已存在
|
||
existing_stmt = select(UserRole).where(
|
||
UserRole.employee_id == employee_id,
|
||
UserRole.role_id == role.id,
|
||
)
|
||
existing_result = await self.db.execute(existing_stmt)
|
||
existing = existing_result.scalars().first()
|
||
|
||
if existing:
|
||
logger.debug(f"用户 {_mask_sensitive_data(employee_id)} 已拥有角色 {role_name},跳过添加")
|
||
return
|
||
|
||
# 创建用户角色关联
|
||
user_role = UserRole(
|
||
employee_id=employee_id,
|
||
role_id=role.id,
|
||
source=source,
|
||
)
|
||
self.db.add(user_role)
|
||
await self.db.commit()
|
||
|
||
logger.info(f"为用户 {_mask_sensitive_data(employee_id)} 添加角色 {role_name}(来源:{source})")
|
||
|
||
async def _remove_auto_role(self, employee_id: str, role_name: str) -> None:
|
||
"""撤销用户的自动分配角色。
|
||
|
||
仅撤销 source 为 auto/tag/ehr 的角色,不撤销手动分配的角色。
|
||
|
||
Args:
|
||
employee_id: 企微 UserID
|
||
role_name: 角色标识
|
||
"""
|
||
# 查询角色
|
||
stmt = select(Role).where(Role.name == role_name)
|
||
result = await self.db.execute(stmt)
|
||
role = result.scalars().first()
|
||
|
||
if not role:
|
||
return
|
||
|
||
# 查询用户角色关联(仅自动分配的)
|
||
user_role_stmt = select(UserRole).where(
|
||
UserRole.employee_id == employee_id,
|
||
UserRole.role_id == role.id,
|
||
UserRole.source.in_(["auto", "tag", "ehr"]), # 仅自动分配的
|
||
)
|
||
user_role_result = await self.db.execute(user_role_stmt)
|
||
user_role = user_role_result.scalars().first()
|
||
|
||
if user_role:
|
||
await self.db.delete(user_role)
|
||
await self.db.commit()
|
||
logger.info(f"撤销用户 {_mask_sensitive_data(employee_id)} 的自动分配角色 {role_name}")
|
||
|
||
async def get_wecom_user_tags(self, user_id: str) -> List[str]:
|
||
"""获取用户的企微标签列表。
|
||
|
||
调用企微通讯录API获取用户的标签ID列表,然后查询标签名称。
|
||
|
||
Args:
|
||
user_id: 企微 UserID
|
||
|
||
Returns:
|
||
List[str]: 标签名称列表
|
||
"""
|
||
if not self.wecom_service:
|
||
logger.warning("WecomService 未初始化,无法获取企微标签")
|
||
return []
|
||
|
||
try:
|
||
# 获取用户信息(包含 tagids)
|
||
user_info = await self.wecom_service.get_user_info(user_id)
|
||
tag_ids = user_info.get("tagids", [])
|
||
|
||
if not tag_ids:
|
||
return []
|
||
|
||
# 查询标签名称
|
||
tag_names = await self._get_tag_names_by_ids(tag_ids)
|
||
return tag_names
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户企微标签失败: user_id={user_id}, error={e}")
|
||
return []
|
||
|
||
# 标签名称验证常量
|
||
MAX_TAG_NAME_LENGTH = 50 # 最大标签名称长度
|
||
TAG_NAME_FORBIDDEN_CHARS = "<>'\"&;\\|%$#@`" # 禁止的特殊字符
|
||
|
||
def _validate_tag_name(self, tag_name: str) -> bool:
|
||
"""验证标签名称是否安全。
|
||
|
||
Args:
|
||
tag_name: 标签名称
|
||
|
||
Returns:
|
||
bool: 是否有效
|
||
"""
|
||
# 检查长度
|
||
if not tag_name or len(tag_name) > self.MAX_TAG_NAME_LENGTH:
|
||
return False
|
||
|
||
# 检查禁止字符
|
||
for char in self.TAG_NAME_FORBIDDEN_CHARS:
|
||
if char in tag_name:
|
||
return False
|
||
|
||
return True
|
||
|
||
async def _get_tag_names_by_ids(self, tag_ids: List[int]) -> List[str]:
|
||
"""根据标签ID列表获取标签名称。
|
||
|
||
调用企微标签管理API获取标签名称,并进行安全验证。
|
||
|
||
Args:
|
||
tag_ids: 标签ID列表
|
||
|
||
Returns:
|
||
List[str]: 验证后的标签名称列表
|
||
"""
|
||
if not self.wecom_service:
|
||
return []
|
||
|
||
try:
|
||
access_token = await self.wecom_service.get_access_token()
|
||
import httpx
|
||
|
||
async with httpx.AsyncClient() as client:
|
||
# 获取标签列表(企微API)
|
||
url = "https://qyapi.weixin.qq.com/cgi-bin/tag/list"
|
||
params = {"access_token": access_token}
|
||
response = await client.get(url, params=params)
|
||
result = response.json()
|
||
|
||
if result.get("errcode", 0) != 0:
|
||
logger.error(f"获取标签列表失败: {result}")
|
||
return []
|
||
|
||
# 构建 tag_id -> tag_name 映射(带安全验证)
|
||
tag_map = {}
|
||
for tag in result.get("taglist", []):
|
||
tag_name = tag.get("tagname", "")
|
||
# 安全验证:过滤不安全的标签名称
|
||
if self._validate_tag_name(tag_name):
|
||
tag_map[tag["tagid"]] = tag_name
|
||
|
||
# 返回匹配的标签名称
|
||
valid_tag_names = [
|
||
tag_map[tag_id]
|
||
for tag_id in tag_ids
|
||
if tag_id in tag_map
|
||
]
|
||
|
||
# 记录获取到的标签数量(非敏感信息)
|
||
logger.debug(f"获取到 {len(valid_tag_names)} 个有效标签")
|
||
return valid_tag_names
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取标签名称失败: {e}")
|
||
return []
|