chore: initial baseline with P0-safety .gitignore

This commit is contained in:
Simon
2026-06-14 16:49:18 +08:00
commit 63262292d7
510 changed files with 146008 additions and 0 deletions
+6
View File
@@ -0,0 +1,6 @@
# =============================================================================
# 企微IT智能服务台 — 外部系统集成模块包
# =============================================================================
# 说明:各外部系统的 API 客户端、数据模型、异常定义等
# 当前已实现:火绒终端安全
#
@@ -0,0 +1,3 @@
# =============================================================================
# 企微IT智能服务台 — 火绒终端安全集成模块包
# =============================================================================
+658
View File
@@ -0,0 +1,658 @@
# =============================================================================
# 企微IT智能服务台 — 火绒终端安全 API 客户端
# =============================================================================
# 说明:封装火绒API的签名、请求、响应处理
# 核心功能:
# 1. HRESS 签名实现(Authorization Header 方式)
# 2. 统一请求封装(超时、重试、异常处理)
# 3. P0 接口:终端列表 _list / 终端详情 _info2 / 高危漏洞 _leak / 病毒事件 _virus_events
# 4. P1 接口:终端隔离/解除 _create(netctrl) / 快速扫描 / 在线终端查询
# 签名算法(来自火绒官方API文档 v1):
# Authorization = "HRESS" + AccessKeyId + ":" + Expires + ":" + Signature
# Signature = urlencode(base64(hmac-sha1(AccessKeySecret,
# AccessKeyId + "\n" + Expires + "\n" + HTTP-METHOD + "\n"
# + Content-MD5 + "\n" + CanonicalizedResource)))
# CanonicalizedResource = API路径(无前导/),含排序后的查询参数
# Content-MD5 = base64(md5_digest(body_bytes)) — RFC2616
# 使用方式:
# client = HuorongClient(access_key_id="...", access_key_secret="...", base_url="...")
# terminals = await client.list_terminals()
# =============================================================================
import base64
import hashlib
import hmac
import json
import logging
import time
from typing import Any, Dict, List, Optional
from urllib.parse import quote
import httpx
from .exceptions import (
HuorongApiError,
HuorongAuthError,
HuorongConnectionError,
HuorongError,
)
from .models import (
HuorongApiResponse,
TerminalBasicInfo,
TerminalDetailV2,
TerminalLeakInfo,
VirusEventStats,
VirusHandleResult,
)
# Module-level logger for the Huorong API client.
logger = logging.getLogger(__name__)
# Default request timeout in seconds. The original author's note says intranet
# responses usually arrive within about one second; the value is kept well
# above that because _virus_events queries over all terminals may be slow.
DEFAULT_TIMEOUT = 10.0
# Default page size used by the paginated query methods.
DEFAULT_PAGE_SIZE = 20
# Signature validity window in seconds; added to the current Unix time to
# build the Expires field of the request signature.
SIGN_EXPIRES_SECONDS = 300
class HuorongClient:
    """Async API client for the Huorong endpoint-security management system.

    Bundles HRESS Authorization-header signing, request dispatch and response
    parsing. The API methods are coroutines and send requests through
    httpx.AsyncClient.

    Attributes:
        access_key_id: Huorong AccessKey ID (the original notes say the
            control center labels it "Secret ID").
        access_key_secret: Huorong AccessKey Secret (labelled "Secret Key"
            in the control center, per the original notes).
        base_url: API base address, stored without a trailing slash.
        timeout: Request timeout in seconds.
    """

    def __init__(
        self,
        access_key_id: str,
        access_key_secret: str,
        base_url: str,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Store credentials and connection settings.

        Args:
            access_key_id: Huorong AccessKey ID.
            access_key_secret: Huorong AccessKey Secret.
            base_url: API base address; trailing slashes are removed.
            timeout: Request timeout in seconds (defaults to DEFAULT_TIMEOUT).
        """
        # Request paths start with "/", so the base URL must not end with one.
        normalized_url = base_url.rstrip("/")

        self.base_url = normalized_url
        self.timeout = timeout
        self.access_key_id = access_key_id
        self.access_key_secret = access_key_secret
# ======================================================================
# 签名实现(HRESS Authorization Header 方式)
# ======================================================================
def _compute_content_md5(self, body_bytes: bytes) -> str:
    """Return the RFC 2616 Content-MD5 value for a request body.

    The value is the base64 encoding of the raw 16-byte MD5 digest, not of
    its hexadecimal representation.

    Args:
        body_bytes: Raw request body.

    Returns:
        str: Base64-encoded MD5 digest.
    """
    hasher = hashlib.md5()
    hasher.update(body_bytes)
    encoded_digest = base64.b64encode(hasher.digest())
    return str(encoded_digest, "utf-8")
def _build_canonicalized_resource(self, path: str) -> str:
    """Build the CanonicalizedResource component of the signature.

    Following the rules quoted from the Huorong API document:
      1. Take the resource path without its leading "/".
      2. If the request carries sub-resources (query parameters), sort them
         lexicographically, join them with "&" and append them after "?".

    Examples:
        /api/clnts/_list             -> "api/clnts/_list"
        /api/group/_info?group_id=1  -> "api/group/_info?group_id=1"
        /api/x?b=2&a=1               -> "api/x?a=1&b=2"

    Args:
        path: Request path, optionally followed by a query string.

    Returns:
        str: The canonicalized resource string.
    """
    stripped = path.lstrip("/")
    resource, _, query = stripped.partition("?")
    if not query:
        # No sub-resources: identical to the previous behaviour.
        return stripped

    ordered = sorted(piece for piece in query.split("&") if piece)
    if not ordered:
        return stripped
    return resource + "?" + "&".join(ordered)
def _sign_request(
    self,
    method: str,
    path: str,
    body_bytes: bytes = b"",
) -> Dict[str, str]:
    """Produce the HRESS Authorization headers for one request.

    Algorithm, as quoted from the Huorong API document:
        Authorization = "HRESS" + AccessKeyId + ":" + Expires + ":" + Signature
        Signature = urlencode(base64(hmac-sha1(AccessKeySecret, StringToSign)))
    where StringToSign is AccessKeyId, Expires, HTTP-METHOD, Content-MD5 and
    CanonicalizedResource joined by newline characters.

    Args:
        method: HTTP method name placed into the string to sign.
        path: Request path (for example /api/clnts/_list).
        body_bytes: Raw request body; an empty body yields an empty
            Content-MD5 component.

    Returns:
        Dict[str, str]: Headers with "Authorization" and "Content-Type".
    """
    # Expires is an absolute Unix timestamp: now plus the validity window.
    expires = str(SIGN_EXPIRES_SECONDS + int(time.time()))

    if body_bytes:
        content_md5 = self._compute_content_md5(body_bytes)
    else:
        content_md5 = ""

    resource = self._build_canonicalized_resource(path)

    # The five components are separated by single newline characters.
    string_to_sign = "\n".join(
        (self.access_key_id, expires, method, content_md5, resource)
    )

    mac = hmac.new(
        self.access_key_secret.encode("utf-8"),
        string_to_sign.encode("utf-8"),
        hashlib.sha1,
    )
    # base64 first, then percent-encode every reserved character.
    signature = quote(base64.b64encode(mac.digest()).decode("utf-8"), safe="")

    return {
        "Authorization": f"HRESS{self.access_key_id}:{expires}:{signature}",
        "Content-Type": "application/json; charset=utf-8",
    }
# ======================================================================
# 通用请求方法
# ======================================================================
async def _request(
    self,
    path: str,
    body: Optional[Dict[str, Any]] = None,
) -> HuorongApiResponse:
    """Send one signed POST request to the Huorong API and parse the reply.

    Steps: build the URL, serialize the body compactly (the signature covers
    the exact bytes sent), sign, POST, then map HTTP and business error codes
    to the exception hierarchy.

    Args:
        path: API path such as /api/clnts/_list.
        body: Optional request payload; a missing or empty payload is sent
            as the literal JSON object "{}".

    Returns:
        HuorongApiResponse: Parsed response whose errcode is 0.

    Raises:
        HuorongConnectionError: Timeout or connection failure.
        HuorongAuthError: HTTP 401, or business code 1 / 401 / 403.
        HuorongApiError: Any other non-200 status or non-zero business code.
        HuorongError: Any unexpected failure (code -1); the original
            exception is attached as __cause__.
    """
    url = f"{self.base_url}{path}"
    body_bytes = json.dumps(body, separators=(",", ":")).encode("utf-8") if body else b"{}"
    headers = self._sign_request("POST", path, body_bytes)

    try:
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            # The Authorization header is deliberately not logged: it carries
            # a signature that stays valid until Expires.
            logger.debug(
                "火绒API请求: POST %s\n AccessKeyID: %s\n Path: %s\n Body: %s",
                url,
                self.access_key_id,
                path,
                body_bytes[:200].decode("utf-8", errors="replace"),
            )
            response = await client.post(url, headers=headers, content=body_bytes)

            # HTTP-level failures.
            if response.status_code == 401:
                raise HuorongAuthError()
            if response.status_code != 200:
                raise HuorongApiError(
                    code=response.status_code,
                    message=f"HTTP {response.status_code}: {response.text[:200]}",
                )

            api_resp = HuorongApiResponse(**response.json())

            # Business error codes listed in the original notes:
            # 0 success, 1 authentication failed, 2 bad parameter,
            # 3 internal server error, 4 API not authorized.
            if api_resp.errcode != 0:
                if api_resp.errcode in (1, 401, 403):
                    raise HuorongAuthError(f"认证/权限失败: {api_resp.errmsg}")
                if api_resp.errcode == 4:
                    raise HuorongApiError(
                        code=api_resp.errcode,
                        message=f"API未授权: {api_resp.errmsg}",
                    )
                raise HuorongApiError(
                    code=api_resp.errcode,
                    message=api_resp.errmsg,
                )
            return api_resp
    except httpx.TimeoutException as exc:
        raise HuorongConnectionError(f"火绒API请求超时({self.timeout}秒): {url}") from exc
    except httpx.ConnectError as exc:
        raise HuorongConnectionError(f"无法连接火绒服务器: {url}") from exc
    except (HuorongAuthError, HuorongApiError, HuorongConnectionError):
        # Already classified; propagate unchanged.
        raise
    except Exception as exc:
        # Anything else (bad JSON, validation failure, other transport
        # errors) is wrapped in the generic base error.
        logger.error("火绒API未预期异常: %s: %s", type(exc).__name__, exc)
        raise HuorongError(code=-1, message=f"火绒API调用异常: {exc}") from exc
# ======================================================================
# P0 接口:查询能力
# ======================================================================
async def list_terminals(
    self,
    group_id: Optional[str] = None,
    page: int = 1,
    per_page: int = DEFAULT_PAGE_SIZE,
) -> Dict[str, Any]:
    """Query the basic-information list of terminals.

    Calls POST /api/clnts/_list. The API paginates with limit/offset (the
    original notes give a maximum limit of 200); page/per_page are converted
    here so callers keep a page-based interface.

    Args:
        group_id: Optional group ID; when given it is sent as an integer.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; clamped to the range 1..200.

    Returns:
        Dict: "total" (count reported by the API) and "items"
        (list of TerminalBasicInfo).
    """
    # Clamp so the API never receives a non-positive limit or negative offset.
    limit = max(1, min(per_page, 200))
    offset = (max(page, 1) - 1) * limit

    body: Dict[str, Any] = {
        "limit": limit,
        "offset": offset,
    }
    if group_id:
        body["group_id"] = int(group_id)

    resp = await self._request("/api/clnts/_list", body)

    data = resp.data or {}
    # Tolerate an explicit null list in the payload.
    raw_items = data.get("list") or []
    return {
        "total": data.get("total", 0),
        "items": [TerminalBasicInfo(**item) for item in raw_items],
    }
async def get_terminal_detail(
    self,
    client_id: str,
    optional_fields: Optional[List[str]] = None,
) -> TerminalDetailV2:
    """Fetch detailed information (v2) for one terminal.

    Calls POST /api/clnts/_info2.

    Args:
        client_id: Unique terminal ID.
        optional_fields: Information blocks to request; when omitted, all of
            hardware, software, assets and netconf are requested.

    Returns:
        TerminalDetailV2: Model built from the response's data section.
    """
    requested_blocks = (
        ["hardware", "software", "assets", "netconf"]
        if optional_fields is None
        else optional_fields
    )
    payload = {"client_id": client_id, "optional_fields": requested_blocks}

    response = await self._request("/api/clnts/_info2", payload)
    detail_fields = response.data or {}
    return TerminalDetailV2(**detail_fields)
async def list_terminal_leaks(
    self,
    group_id: Optional[str] = None,
    page: int = 1,
    per_page: int = DEFAULT_PAGE_SIZE,
) -> Dict[str, Any]:
    """Query terminals that have unpatched high-risk vulnerabilities.

    Calls POST /api/clnts/_leak, which paginates with limit/offset. Each
    record describes a terminal (not a vulnerability) and is parsed into
    TerminalLeakInfo. The payload also carries all_client and risk_client
    counters; no separate "total" is read from it, so risk_client is
    reported as the total.

    Args:
        group_id: Optional group ID; when given it is sent as an integer.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; clamped to the range 1..200.

    Returns:
        Dict: "total" and "risk_client" (both the risk_client counter),
        "all_client", and "items" (list of TerminalLeakInfo).
    """
    # Clamp so the API never receives a non-positive limit or negative offset.
    limit = max(1, min(per_page, 200))
    offset = (max(page, 1) - 1) * limit

    body: Dict[str, Any] = {
        "limit": limit,
        "offset": offset,
    }
    if group_id:
        body["group_id"] = int(group_id)

    resp = await self._request("/api/clnts/_leak", body)

    data = resp.data or {}
    # Tolerate an explicit null list in the payload.
    raw_items = data.get("list") or []
    all_client = data.get("all_client", 0)
    risk_client = data.get("risk_client", 0)
    return {
        "total": risk_client,
        "all_client": all_client,
        "risk_client": risk_client,
        "items": [TerminalLeakInfo(**item) for item in raw_items],
    }
async def get_virus_events(
    self,
    client_id: Optional[str] = None,
    group_id: Optional[str] = None,
    query_type: int = 2,
    begin_time: Optional[int] = None,
    end_time: Optional[int] = None,
    page: int = 1,
    per_page: int = DEFAULT_PAGE_SIZE,
) -> Dict[str, Any]:
    """Query per-terminal virus-event statistics.

    Calls POST /api/clnts/_virus_events. Query types, as listed in the
    original notes: 0 = by client_id, 1 = by group_id, 2 = all terminals.

    Args:
        client_id: Unique terminal ID; sent only when query_type is 0.
        group_id: Group ID; sent as an integer when query_type is 0 or 1.
        query_type: Query mode, default 2 (all terminals).
        begin_time: Optional range start (Unix timestamp); sent whenever it
            is not None, including 0.
        end_time: Optional range end (Unix timestamp); same rule.
        page: 1-based page number; values below 1 are treated as 1.
        per_page: Page size; clamped to the range 1..200.

    Returns:
        Dict: "total" and "items" (list of VirusEventStats).
    """
    # Clamp so the API never receives a non-positive limit or negative offset.
    limit = max(1, min(per_page, 200))
    offset = (max(page, 1) - 1) * limit

    body: Dict[str, Any] = {
        "type": query_type,
        "limit": limit,
        "offset": offset,
    }

    if query_type == 0 and client_id:
        body["client_id"] = client_id
    # NOTE(review): group_id is also forwarded for type 0, as before; the
    # notes describe it as required only for type 1 - confirm intent.
    if query_type in (0, 1) and group_id:
        body["group_id"] = int(group_id)

    # Compare against None so an explicit timestamp of 0 is not dropped.
    if begin_time is not None:
        body["begin_time"] = begin_time
    if end_time is not None:
        body["end_time"] = end_time

    resp = await self._request("/api/clnts/_virus_events", body)

    data = resp.data or {}
    # Tolerate an explicit null list in the payload.
    raw_items = data.get("list") or []
    return {
        "total": data.get("total", 0),
        "items": [VirusEventStats(**item) for item in raw_items],
    }
# ======================================================================
# P1 接口:控制能力
# ======================================================================
async def isolate_terminal(
    self,
    client_ids: List[str],
) -> Dict[str, Any]:
    """Isolate terminals from the network.

    Calls POST /api/task/_create with type "netctrl" and net_isolation true.
    This is a high-risk action; the original notes require callers to
    restrict it to admins, obtain a second confirmation and record the
    reason. None of that is enforced here.

    Args:
        client_ids: IDs of the terminals to isolate.

    Returns:
        Dict: The data section of the API response, or {} when it is empty.
    """
    payload: Dict[str, Any] = dict(
        type="netctrl",
        net_isolation=True,
        clients=client_ids,
    )
    response = await self._request("/api/task/_create", payload)
    # Logged at warning level only after the request succeeded.
    logger.warning(f"火绒终端隔离操作: client_ids={client_ids}")
    result = response.data
    return result if result else {}
async def unisolate_terminal(
    self,
    client_ids: List[str],
) -> Dict[str, Any]:
    """Lift network isolation from terminals.

    Calls POST /api/task/_create with type "netctrl" and net_isolation false.

    Args:
        client_ids: IDs of the terminals to release.

    Returns:
        Dict: The data section of the API response, or {} when it is empty.
    """
    payload: Dict[str, Any] = dict(
        type="netctrl",
        net_isolation=False,
        clients=client_ids,
    )
    response = await self._request("/api/task/_create", payload)
    logger.info(f"火绒终端解除隔离: client_ids={client_ids}")
    result = response.data
    return result if result else {}
async def create_scan_task(
    self,
    client_ids: List[str],
    scan_type: str = "quick_scan",
) -> Dict[str, Any]:
    """Create a scan task on terminals.

    Calls POST /api/task/_create with the scan type as the task type. The
    original notes list quick_scan, full_scan and custom_scan; the value is
    forwarded without validation.

    Args:
        client_ids: IDs of the target terminals.
        scan_type: Task type to submit, default "quick_scan".

    Returns:
        Dict: The data section of the API response, or {} when it is empty.
    """
    payload: Dict[str, Any] = dict(type=scan_type, clients=client_ids)
    response = await self._request("/api/task/_create", payload)
    logger.info(f"火绒终端扫描任务: type={scan_type}, client_ids={client_ids}")
    result = response.data
    return result if result else {}
async def send_notification(
    self,
    client_ids: List[str],
    content: str,
) -> Dict[str, Any]:
    """Send a notification message to terminals.

    Calls POST /api/task/_create with type "message".

    Args:
        client_ids: IDs of the target terminals.
        content: Notification text; only its first 50 characters are logged.

    Returns:
        Dict: The data section of the API response, or {} when it is empty.
    """
    payload: Dict[str, Any] = dict(
        type="message",
        clients=client_ids,
        content=content,
    )
    response = await self._request("/api/task/_create", payload)
    logger.info(f"火绒终端通知: client_ids={client_ids}, content={content[:50]}")
    result = response.data
    return result if result else {}
# ======================================================================
# 测试连接
# ======================================================================
async def test_connection(self) -> Dict[str, Any]:
    """Check that the Huorong API is reachable and credentials are accepted.

    Performs a one-item list_terminals call as a lightweight probe and
    converts any HuorongError into a failure result instead of raising.

    Returns:
        Dict: "success" (bool) and "message" (str); on success also
        "total_terminals".
    """
    try:
        listing = await self.list_terminals(page=1, per_page=1)
    except HuorongError as err:
        # Most specific categories first; everything else is a generic failure.
        if isinstance(err, HuorongAuthError):
            failure_message = f"认证失败: {err.message}"
        elif isinstance(err, HuorongConnectionError):
            failure_message = f"连接失败: {err.message}"
        else:
            failure_message = f"测试失败: {err.message}"
        return {
            "success": False,
            "message": failure_message,
        }

    return {
        "success": True,
        "message": f"连接成功,共 {listing.get('total', 0)} 个终端",
        "total_terminals": listing.get("total", 0),
    }
@@ -0,0 +1,87 @@
# =============================================================================
# 企微IT智能服务台 — 火绒集成配置管理
# =============================================================================
# 说明:从系统配置表(system_configs)读取火绒 AccessKey/Secret/BaseUrl
# 构建火绒API客户端实例
# =============================================================================
import logging
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.system_config import SystemConfig
from .client import HuorongClient
from .exceptions import HuorongConfigError
# Module-level logger for the Huorong configuration helpers.
logger = logging.getLogger(__name__)
# Key prefix of the Huorong settings stored in the system_configs table.
HUORONG_CONFIG_PREFIX = "integration_huorong_"
async def get_huorong_client(db: AsyncSession) -> HuorongClient:
    """Build a HuorongClient from settings stored in the system config table.

    Loads every SystemConfig row whose key starts with HUORONG_CONFIG_PREFIX
    and reads the access_key_id, access_key_secret and base_url entries.

    Args:
        db: Async database session.

    Returns:
        HuorongClient: Client configured with the stored values.

    Raises:
        HuorongConfigError: If any of the three settings is missing or empty.
    """
    statement = select(SystemConfig).where(
        SystemConfig.config_key.startswith(HUORONG_CONFIG_PREFIX)
    )
    rows = (await db.execute(statement)).scalars().all()

    # key -> value lookup for the fetched rows.
    settings = {}
    for row in rows:
        settings[row.config_key] = row.config_value

    access_key_id = settings.get(f"{HUORONG_CONFIG_PREFIX}access_key_id", "")
    access_key_secret = settings.get(f"{HUORONG_CONFIG_PREFIX}access_key_secret", "")
    base_url = settings.get(f"{HUORONG_CONFIG_PREFIX}base_url", "")

    # Collect the display names of every setting that is absent or empty.
    required = (
        ("AccessKey ID", access_key_id),
        ("AccessKey Secret", access_key_secret),
        ("Base URL", base_url),
    )
    missing = [label for label, value in required if not value]
    if missing:
        raise HuorongConfigError(
            f"火绒集成配置不完整,缺失: {', '.join(missing)},请先在管理后台完成配置"
        )

    return HuorongClient(
        access_key_id=access_key_id,
        access_key_secret=access_key_secret,
        base_url=base_url,
    )
async def is_huorong_configured(db: AsyncSession) -> bool:
    """Report whether the Huorong integration is fully configured.

    Args:
        db: Async database session.

    Returns:
        bool: True when a client can be built and its key ID, secret and
        base URL are all non-empty; False when configuration is incomplete.
    """
    try:
        client = await get_huorong_client(db)
    except HuorongConfigError:
        return False
    else:
        # Re-checked on the client because it normalizes base_url.
        required = (client.access_key_id, client.access_key_secret, client.base_url)
        return all(required)
@@ -0,0 +1,63 @@
# =============================================================================
# 企微IT智能服务台 — 火绒集成自定义异常
# =============================================================================
# 说明:火绒API调用中可能抛出的各种异常类型
# 包含:认证错误、连接超时、API错误码等
# =============================================================================
class HuorongError(Exception):
    """Base class for all Huorong integration errors.

    Attributes:
        code: Error code (an API error code or a locally chosen one).
        message: Human-readable description.
    """

    def __init__(self, code: int = -1, message: str = "火绒API调用失败"):
        rendered = f"[HuorongError:{code}] {message}"
        super().__init__(rendered)
        self.message = message
        self.code = code
class HuorongAuthError(HuorongError):
    """Authentication or authorization failure (always carries code 401).

    Intended for invalid AccessKey credentials, signature verification
    failures and insufficient permissions.
    """

    def __init__(self, message: str = "火绒API认证失败,请检查AccessKey配置"):
        super().__init__(401, message)
class HuorongConnectionError(HuorongError):
    """Connection failure (always carries code 502).

    Intended for unreachable addresses, timeouts and DNS failures.
    """

    def __init__(self, message: str = "无法连接火绒服务器,请检查网络和Base URL配置"):
        super().__init__(502, message)
class HuorongConfigError(HuorongError):
    """Missing configuration (always carries code 400).

    Intended for an AccessKey ID, Secret or Base URL that has not been set
    in the system configuration.
    """

    def __init__(self, message: str = "火绒集成未配置,请先在管理后台设置AccessKey和Base URL"):
        super().__init__(400, message)
class HuorongApiError(HuorongError):
    """Business-level error reported by the Huorong API.

    Carries the caller-supplied error code and message unchanged.
    """

    def __init__(self, code: int, message: str):
        super().__init__(code, message)
+373
View File
@@ -0,0 +1,373 @@
# =============================================================================
# 企微IT智能服务台 — 火绒集成数据模型
# =============================================================================
# 说明:火绒API请求/响应的 Pydantic 数据模型
# 包含:终端信息、漏洞信息、病毒事件、任务下发等
# =============================================================================
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, model_validator
# ==========================================================================
# 通用响应模型
# ==========================================================================
class HuorongApiResponse(BaseModel):
    """Uniform response envelope of the Huorong API.

    According to the original notes the API replies with
    {"errno": 0, "errmsg": "", "data": {...}} on success and
    {"errno": <code>, "errmsg": "..."} on failure, with codes
    0 success, 1 authentication failed, 2 bad parameter,
    3 internal server error, 4 API not authorized.

    The wire field is "errno"; it is renamed to "errcode" before validation
    so the rest of the code base uses a single name.

    Attributes:
        errcode: Error code, 0 on success (normalized from errno).
        errmsg: Error description; defaults to "ok" when absent.
        data: Business payload, if any.
    """

    @model_validator(mode="before")
    @classmethod
    def normalize_error_fields(cls, data: Any) -> Any:
        """Rename an incoming "errno" key to "errcode".

        Works on a shallow copy so the caller's dictionary is never mutated.

        Args:
            data: Raw validator input (normally a dict).

        Returns:
            The input unchanged, or a copy with errno renamed to errcode.
        """
        if isinstance(data, dict) and "errno" in data and "errcode" not in data:
            normalized = dict(data)
            normalized["errcode"] = normalized.pop("errno")
            return normalized
        return data

    errcode: int = Field(..., description="错误码,0=成功")
    errmsg: str = Field(default="ok", description="错误描述")
    data: Optional[Any] = Field(default=None, description="业务数据")
# ==========================================================================
# 终端基本信息 — /api/clnts/_list 返回
# ==========================================================================
class TerminalBasicInfo(BaseModel):
    """One terminal record as returned by the _list endpoint.

    Field names follow the API response. Only client_id is required; every
    other field falls back to an empty or None default.
    """

    # Internal database ID.
    id: Optional[int] = Field(None, description="内部数据库ID")
    # Unique terminal ID.
    client_id: str = Field(..., description="终端唯一ID")
    client_name: str = Field("", description="客户端名称")
    computer_name: str = Field("", description="计算机名")
    local_ip: str = Field("", description="本地IP")
    # IP the client connects with.
    connect_ip: str = Field("", description="连接IP")
    mac: str = Field("", description="MAC地址")
    # Accepted as-is; typed Any in this model.
    group_id: Optional[Any] = Field(None, description="分组IDint或str)")
    os_version: str = Field("", description="操作系统版本")
    # Huorong client version.
    version: str = Field("", description="火绒客户端版本")
    # Virus-definition update time.
    definitions: str = Field("", description="病毒库更新时间")
    is_online: bool = Field(False, description="在线状态")
    # Timestamps, described as Unix times in the original notes.
    last_connect_time: Optional[int] = Field(None, description="最后连接时间")
    last_seen_time: Optional[int] = Field(None, description="最后可见时间")
    first_appear_time: Optional[int] = Field(None, description="首次出现时间")
class TerminalListRequest(BaseModel):
    """Request parameters for listing terminals.

    Attributes:
        group_id: Optional group ID.
        page: Page number, at least 1.
        per_page: Page size, between 1 and 100.
    """

    group_id: Optional[str] = Field(None, description="分组ID")
    page: int = Field(1, ge=1, description="页码")
    per_page: int = Field(20, ge=1, le=100, description="每页条数")
# ==========================================================================
# 终端详细信息v2 — /api/clnts/_info2 返回
# ==========================================================================
class HardwareInfo(BaseModel):
    """Hardware summary of a terminal; every field is a free-form string."""

    cpu: str = Field("", description="CPU信息")
    memory: str = Field("", description="内存信息")
    disk: str = Field("", description="磁盘信息")
    motherboard: str = Field("", description="主板信息")
    network_card: str = Field("", description="网卡信息")
class SoftwareInfo(BaseModel):
    """One installed-software entry: name, version and publisher."""

    name: str = Field("", description="软件名称")
    version: str = Field("", description="版本号")
    publisher: str = Field("", description="发布者")
class AssetInfo(BaseModel):
    """Asset information: asset tag and serial number."""

    asset_tag: str = Field("", description="资产标签")
    serial_number: str = Field("", description="序列号")
class NetworkConfig(BaseModel):
    """Network configuration of a terminal; every field is a string."""

    ip: str = Field("", description="IP地址")
    gateway: str = Field("", description="网关")
    dns: str = Field("", description="DNS服务器")
    adapter_info: str = Field("", description="网卡适配器信息")
class TerminalDetailV2(BaseModel):
    """Detailed terminal information (v2) returned by the _info2 endpoint.

    The optional blocks (hardware, software, assets, netconf) stay None when
    they are absent from the payload.
    """

    client_id: str = Field(..., description="终端唯一ID")
    computer_name: str = Field("", description="计算机名")
    hardware: Optional[HardwareInfo] = Field(None, description="硬件信息")
    software: Optional[List[SoftwareInfo]] = Field(None, description="已安装软件")
    assets: Optional[AssetInfo] = Field(None, description="资产信息")
    netconf: Optional[NetworkConfig] = Field(None, description="网络配置")
class TerminalDetailRequest(BaseModel):
    """Request parameters for the terminal-detail query.

    Attributes:
        client_id: Unique terminal ID (required).
        optional_fields: Information blocks to return; a fresh list of all
            four block names is created for each instance by default.
    """

    client_id: str = Field(..., description="终端唯一ID")
    optional_fields: List[str] = Field(
        description="可选信息块: hardware/software/assets/netconf",
        default_factory=lambda: ["hardware", "software", "assets", "netconf"],
    )
# ==========================================================================
# 漏洞信息 — /api/clnts/_leak 返回
# 说明:_leak 接口返回的是"存在高危漏洞未修复的终端列表",
# 每条记录是终端信息(非漏洞详情),API不返回具体漏洞CVE列表。
# 外层还有 all_client(终端总数)和 risk_client(高危终端数)统计。
# ==========================================================================
class TerminalLeakInfo(BaseModel):
    """One terminal record as returned by the _leak endpoint.

    Records are terminal-level, not vulnerability-level, and use different
    field names from the _list endpoint (for example cid instead of
    client_id, hostname instead of computer_name). Only cid is required.
    """

    # Unique terminal ID (client_id in _list).
    cid: str = Field(..., description="终端唯一ID")
    # Computer name (computer_name in _list).
    hostname: str = Field("", description="计算机名")
    client_name: str = Field("", description="终端名称")
    group_name: str = Field("", description="分组名称")
    group_id: Optional[Any] = Field(None, description="分组ID")
    # Local IP (local_ip in _list).
    ip_addr: str = Field("", description="本地IP")
    # Connection IP (connect_ip in _list).
    call_ip: str = Field("", description="连接IP")
    mac: str = Field("", description="MAC地址")
    # OS version (os_version in _list).
    osver: str = Field("", description="操作系统版本")
    os_type: str = Field("", description="终端类型")
    # Huorong client version (version in _list).
    prodver: str = Field("", description="火绒客户端版本")
    # Virus database version (definitions in _list).
    virdb: Optional[Any] = Field(None, description="病毒库版本(Unix时间戳)")
    # Status code: 1 offline, 2 online, 3 abnormal (per the original notes).
    stat: int = Field(1, description="在线状态码: 1=离线 2=在线 3=异常")
# ==========================================================================
# 病毒事件 — /api/clnts/_virus_events 返回
# 说明:_virus_events 返回终端维度的病毒日志统计,
# 含总数(count)和4种处理结果(result)的明细。
# 请求需指定 type: 0=按client_id查, 1=按group_id查, 2=查全部
# ==========================================================================
class VirusHandleResult(BaseModel):
    """Counts of virus events per handling outcome; each defaults to 0."""

    # Handled successfully.
    success: int = Field(0, description="处理成功数")
    # Handling failed.
    fail: int = Field(0, description="处理失败数")
    # Left unhandled for now.
    ignored: int = Field(0, description="暂不处理数")
    # Marked as trusted.
    trusted: int = Field(0, description="已信任数")
class VirusEventStats(BaseModel):
    """Per-terminal virus-event statistics from the _virus_events endpoint.

    Identification fields use the same names as the _list endpoint. Only
    client_id is required.
    """

    group_id: Optional[Any] = Field(None, description="分组ID")
    client_id: str = Field(..., description="终端唯一ID")
    client_name: str = Field("", description="终端名称")
    computer_name: str = Field("", description="计算机名")
    local_ip: str = Field("", description="本地IP")
    connect_ip: str = Field("", description="连接IP")
    mac: str = Field("", description="MAC地址")
    # Total number of virus log entries.
    count: int = Field(0, description="病毒日志总数")
    # Breakdown by handling outcome.
    result: Optional[VirusHandleResult] = Field(None, description="处理结果统计")
# ==========================================================================
# 终端任务 — /api/task/_create
# ==========================================================================
class TaskCreateRequest(BaseModel):
    """Request parameters for creating a terminal task.

    Task types named in the original notes: quick_scan, full_scan,
    custom_scan, netctrl (isolate / release) and message (notification).
    The type string itself is not validated by this model.

    Attributes:
        task_type: Task type (required).
        client_ids: Target terminal IDs; at least one is required.
        net_isolation: Isolation flag, meaningful for netctrl tasks.
        message_content: Notification text, meaningful for message tasks.
    """

    task_type: str = Field(..., description="任务类型: quick_scan/full_scan/custom_scan/netctrl/message")
    client_ids: List[str] = Field(..., min_length=1, description="目标终端ID列表")
    net_isolation: Optional[bool] = Field(None, description="是否隔离(仅netctrl类型)")
    message_content: Optional[str] = Field(None, description="通知内容(仅message类型)")
# ==========================================================================
# 终端安全画像(聚合模型,供前端直接使用)
# ==========================================================================
class TerminalSecurityProfile(BaseModel):
    """Aggregated security profile of one terminal.

    Combines basic terminal information with security indicators in a
    single model. Only client_id is required.
    """

    client_id: str = Field(..., description="终端唯一ID")
    computer_name: str = Field("", description="计算机名")
    ip: str = Field("", description="本地IP")
    mac: str = Field("", description="MAC地址")
    os_version: str = Field("", description="操作系统版本")
    is_online: bool = Field(False, description="在线状态")
    group_name: str = Field("", description="分组名称")
    hardware: Optional[HardwareInfo] = Field(None, description="硬件概要")
    # Number of high-risk vulnerabilities.
    high_risk_leaks: int = Field(0, description="高危漏洞数")
    # Number of unhandled virus events.
    uncleaned_virus: int = Field(0, description="未处理病毒事件数")
    # Described as a 0-100 score; the range is not enforced by this model.
    security_score: int = Field(100, description="安全评分(0-100)")
@@ -0,0 +1,15 @@
# 联软LV7000 API集成模块
"""
提供联软LV7000终端安全管理系统的API客户端。
认证方式:三层认证(IP白名单 + 账号密码 + Token)
- 第一层:IP白名单(在联软后台配置WhiteListServerIp)
- 第二层:账号密码(ApiAccount + ApiPassword)
- 第三层:一次性Token(先调getToken获取,30分钟有效)
核心P0接口:
- queryDevByParams:按条件查询终端(含strusername员工账号映射)
- getDevAllInfo:终端详细信息(硬件+软件+资产+网络)
- getUserInfoByAccount:按账号查用户信息
- getAllOrgInfo:全量组织架构同步
"""
+604
View File
@@ -0,0 +1,604 @@
# 联软LV7000 API客户端
"""
联软LV7000终端安全管理系统 API 客户端。
认证流程:
1. 第一层:IP白名单(在联软后台配置,调用时自动生效)
2. 第二层:账号密码(ApiAccount + ApiPassword)
3. 第三层:Token(先调getToken获取,30分钟有效,自动缓存+刷新)
接口调用方式:
- GET请求:参数通过query string传递
- POST请求:参数通过form-data传递
- 统一携带 token + apiAccount + apiPassword + validatekey
使用示例:
client = LianruanClient(base_url, api_account, api_password, validate_key)
terminals = await client.query_dev_by_params(strusername="songxian")
detail = await client.get_dev_all_info(strdevname="IT-SONGXIAN")
"""
import time
import logging
from typing import Optional
import httpx
from app.integrations.lianruan.exceptions import (
LianruanApiError,
LianruanAuthError,
LianruanConnectionError,
)
from app.integrations.lianruan.models import (
TerminalBasicInfo,
TerminalAllInfo,
UserInfo,
OrgDeptInfo,
OnlineStatus,
TerminalSoftwareInfo,
)
# Module-level logger for the Lianruan API client.
logger = logging.getLogger(__name__)
class LianruanClient:
    """API client for the Lianruan LV7000 system.

    Attributes:
        base_url: API base address, stored without a trailing slash.
        api_account: API account name.
        api_password: API password.
        validate_key: Validation key (may be empty).
        timeout: Request timeout in seconds.
        _token: Cached token string ("" when none has been fetched).
        _token_expire: Unix timestamp at which the cached token expires.
        _client: Lazily created httpx.AsyncClient, or None.
    """

    def __init__(
        self,
        base_url: str,
        api_account: str,
        api_password: str,
        validate_key: str = "",
        timeout: float = 30.0,
    ):
        """Store connection settings and start with empty token/client caches."""
        # Credentials and request settings.
        self.api_account = api_account
        self.api_password = api_password
        self.validate_key = validate_key
        self.timeout = timeout
        self.base_url = base_url.rstrip("/")

        # Token cache; populated on first use.
        self._token_expire: float = 0.0
        self._token: str = ""

        # Shared HTTP client, created on first use.
        self._client: Optional[httpx.AsyncClient] = None
async def _get_client(self) -> httpx.AsyncClient:
    """Return the shared httpx client, creating it when absent or closed."""
    current = self._client
    if current is not None and not current.is_closed:
        return current

    # verify=False turns off TLS certificate verification; the original
    # author's note gives self-signed intranet certificates as the reason.
    self._client = httpx.AsyncClient(timeout=self.timeout, verify=False)
    return self._client
async def close(self) -> None:
    """Close the httpx client if one exists and is still open."""
    client = self._client
    if not client or client.is_closed:
        return
    await client.aclose()
# ==========================================================================
# Token管理
# ==========================================================================
async def _ensure_token(self) -> str:
    """Return a usable token, fetching a new one when the cache is stale.

    A cached token is reused while it has more than five minutes left;
    a freshly fetched token is cached for 30 minutes.

    Returns:
        str: The token, or an empty string when the server answered
        SUCCESS but no token could be found in the payload.

    Raises:
        LianruanAuthError: getToken answered with a status other than SUCCESS.
        LianruanConnectionError: The server could not be reached or timed out.
        LianruanApiError: The server answered with an HTTP error status.
    """
    now = time.time()
    # Reuse the cached token while it has more than 5 minutes of life left.
    if self._token and now < self._token_expire - 300:
        return self._token

    logger.info("联软Token过期或为空,正在刷新...")
    params = {
        "act": "getToken",
        "apiAccount": self.api_account,
        "apiPassword": self.api_password,
    }
    if self.validate_key:
        params["validatekey"] = self.validate_key

    try:
        client = await self._get_client()
        resp = await client.get(f"{self.base_url}/token", params=params)
        resp.raise_for_status()
        data = resp.json()
    except httpx.ConnectError as e:
        raise LianruanConnectionError(
            f"无法连接联软服务器 {self.base_url}: {e}",
            detail=str(e),
        ) from e
    except httpx.TimeoutException as e:
        raise LianruanConnectionError(
            f"连接联软服务器超时: {e}",
            detail=str(e),
        ) from e
    except httpx.HTTPStatusError as e:
        # Same mapping as _request, so callers never see a raw httpx error.
        raise LianruanApiError(
            f"联软HTTP错误 {e.response.status_code}",
            status=str(e.response.status_code),
            detail=str(e),
        ) from e

    if data.get("status") != "SUCCESS":
        raise LianruanAuthError(
            f"获取Token失败: {data.get('msg', '未知错误')}",
            detail=str(data),
        )

    # Response layouts differ between versions: take the first non-empty
    # value among "data", "rows" and "token".
    found = data.get("data") or data.get("rows") or data.get("token") or ""
    self._token = str(found)
    if not self._token:
        # Nothing to cache; the next call will try again.
        self._token_expire = 0.0
        logger.warning("联软getToken响应中未找到Token")
        return ""

    self._token_expire = now + 1800  # 30-minute validity
    logger.info(
        "联软Token刷新成功,有效期至 %s",
        time.strftime("%H:%M:%S", time.localtime(self._token_expire)),
    )
    return self._token
# ==========================================================================
# 通用请求方法
# ==========================================================================
async def _request(
    self,
    path: str,
    act: str,
    params: Optional[dict] = None,
    method: str = "GET",
) -> dict:
    """Send one authenticated call to the Lianruan API.

    The authentication parameters (act, apiAccount, apiPassword, token and,
    when set, validatekey) are merged with the business parameters; business
    parameters win on key clashes. When the server answers ``INVALID`` the
    cached token is dropped and the call is repeated once with a fresh token.

    Args:
        path: API path appended to ``base_url``, e.g. ``/terminal``.
        act: Operation name, e.g. ``queryDevByParams``.
        params: Extra business parameters.
        method: ``POST`` sends the parameters as form data; anything else
            sends them as a GET query string.

    Returns:
        dict: The decoded JSON body; only returned when ``status`` is SUCCESS.

    Raises:
        LianruanAuthError: Still ``INVALID`` after the single retry.
        LianruanApiError: HTTP error status, or a non-SUCCESS business status.
        LianruanConnectionError: Connection failure or timeout.
    """
    url = f"{self.base_url}{path}"
    use_post = method.upper() == "POST"
    data: dict = {}
    status = ""

    # At most two attempts: the second runs only after an INVALID answer
    # forced the token cache to be cleared.
    for attempt in range(2):
        token = await self._ensure_token()
        client = await self._get_client()

        full_params = {
            "act": act,
            "apiAccount": self.api_account,
            "apiPassword": self.api_password,
            "token": token,
        }
        if self.validate_key:
            full_params["validatekey"] = self.validate_key
        if params:
            full_params.update(params)

        try:
            if use_post:
                resp = await client.post(url, data=full_params)
            else:
                resp = await client.get(url, params=full_params)
            resp.raise_for_status()
            data = resp.json()
        except httpx.ConnectError as e:
            raise LianruanConnectionError(
                f"无法连接联软服务器: {e}",
                detail=str(e),
            ) from e
        except httpx.TimeoutException as e:
            raise LianruanConnectionError(
                f"请求联软超时: {e}",
                detail=str(e),
            ) from e
        except httpx.HTTPStatusError as e:
            raise LianruanApiError(
                f"联软HTTP错误 {e.response.status_code}",
                status=str(e.response.status_code),
                detail=str(e),
            ) from e

        status = data.get("status", "")
        if status != "INVALID":
            break

        # The token may have expired server-side: forget it either way.
        self._token = ""
        self._token_expire = 0
        if attempt == 0:
            logger.warning("联软返回INVALID,已清除Token缓存并重试一次")
            continue
        raise LianruanAuthError(
            f"联软认证失败(IP不在白名单或Token无效): {data.get('msg', '')}",
            detail=str(data),
        )

    if status == "SUCCESS":
        return data

    if status == "ERROR":
        message = f"联软API错误: {data.get('msg', '')}"
    elif status == "Exceed":
        message = f"联软数据量超限: {data.get('msg', '')}"
    else:
        message = f"联软未知状态: {status} - {data.get('msg', '')}"
    raise LianruanApiError(message, status=status, detail=str(data))
# ==========================================================================
# P0接口 — 终端设备查询
# ==========================================================================
async def query_dev_by_params(
    self,
    strusername: str = "",
    strdevname: str = "",
    strdevip: str = "",
    strmac: str = "",
    page: int = 1,
    per_page: int = 20,
) -> dict:
    """Query terminals by any combination of filters (queryDevByParams).

    Args:
        strusername: Employee account; lets a terminal be found directly
            from the account name.
        strdevname: Computer name.
        strdevip: IP address.
        strmac: MAC address.
        page: Page number, sent as the ``page`` parameter.
        per_page: Page size, sent as the ``rows`` parameter.

    Returns:
        dict: ``{"items": [TerminalBasicInfo, ...], "total": int}``. When the
        response has no ``total`` the number of returned rows is used.
    """
    filters = (
        ("strusername", strusername),
        ("strdevname", strdevname),
        ("strdevip", strdevip),
        ("strmac", strmac),
    )
    # Only non-empty filters are sent; paging values go out as strings.
    query: dict = {field: value for field, value in filters if value}
    query.update(page=str(page), rows=str(per_page))

    payload = await self._request("/terminal", "queryDevByParams", query)
    records = payload.get("rows", [])
    return {
        "items": [TerminalBasicInfo(**record) for record in records],
        "total": payload.get("total", len(records)),
    }
async def get_dev_all_info(
    self,
    strdevname: str = "",
    strdevip: str = "",
) -> TerminalAllInfo:
    """Fetch the detailed record of one terminal (getDevAllInfo).

    Args:
        strdevname: Computer name; sent only when non-empty.
        strdevip: IP address; sent only when non-empty.

    Returns:
        TerminalAllInfo: Built from the ``equipment`` and ``equipmentdetail``
        sections of the response. Missing or empty sections yield empty
        fields instead of an error.

    Raises:
        LianruanAuthError, LianruanApiError, LianruanConnectionError:
            Propagated from ``_request``.
    """
    # Local import (same approach as _parse_hardware_list): these three
    # models are not part of the module-level import list, so using them
    # bare raised NameError as soon as a device reported disks, NICs or
    # displays.
    from app.integrations.lianruan.models import (
        DisplayInfo,
        LogicalDiskInfo,
        NetworkCardInfo,
    )

    def as_dict(value) -> dict:
        """Unwrap a list to its first element; anything that is not a dict becomes {}."""
        if isinstance(value, list):
            value = value[0] if value else {}
        return value if isinstance(value, dict) else {}

    params: dict = {}
    if strdevname:
        params["strdevname"] = strdevname
    if strdevip:
        params["strdevip"] = strdevip

    data = await self._request(
        "/devallinfoshowwithpaging", "getDevAllInfo", params
    )

    # Either section may arrive as a dict or as a (possibly empty) list.
    equipment = as_dict(data.get("equipment", data.get("rows", [{}])))
    equipment_detail = as_dict(data.get("equipmentdetail", {}))
    dev_detail = as_dict(equipment_detail.get("devdetail", equipment_detail))

    result = TerminalAllInfo(
        strdevname=equipment.get("strdevname", ""),
        strip1=equipment.get("strip1", ""),
        strmac=equipment.get("strmac", ""),
        strdeptname=equipment.get("strdeptname", ""),
        strusername=equipment.get("strusername", ""),
        struserdes=equipment.get("struserdes", ""),
        stros=equipment.get("stros", ""),
        strdomain=equipment.get("strdomain", ""),
        istatus=dev_detail.get("istatus", equipment.get("istatus", "")),
        strverofuaagent=dev_detail.get("strverofuaagent", ""),
        devassetno=dev_detail.get("devassetno", ""),
        devgroup=dev_detail.get("devgroup", ""),
    )

    # Generic hardware sections share one parser.
    hardware_sections = (
        ("CPUInformation", result.cpu),
        ("MemoryInformation", result.memory),
        ("HardDiskInformation", result.hard_disk),
        ("GraphicsCardInformation", result.graphics_card),
        ("MainboardInformation", result.mainboard),
    )
    for section, target in hardware_sections:
        self._parse_hardware_list(dev_detail, section, target)

    # Logical disks (the only section carrying a usage percentage).
    for ld in dev_detail.get("LogicalDiskInformation") or []:
        result.logical_disk.append(LogicalDiskInfo(
            name=ld.get("strlogicaldiskname", ""),
            file_system=ld.get("strfilesystem", ""),
            total_size=ld.get("strtotalsize", ""),
            free_space=ld.get("strfreespace", ""),
            usage_percent=ld.get("strusagepercent", ""),
        ))

    # Network cards.
    for nc in dev_detail.get("NetworkCardInformation") or []:
        result.network_card.append(NetworkCardInfo(
            name=nc.get("strnetcardname", ""),
            is_wireless=nc.get("iswireless", ""),
            vendor=nc.get("strnetcardvendor", ""),
            mac=nc.get("strnetcardmac", ""),
        ))

    # Displays.
    for d in dev_detail.get("DisplayInformation") or []:
        result.display.append(DisplayInfo(
            vendor=d.get("strdisplayvendor", ""),
            model=d.get("strdisplaymodel", ""),
            serial=d.get("strdisplayserial", ""),
            size=d.get("strdisplaysize", ""),
        ))

    return result
def _parse_hardware_list(
    self, dev_detail: dict, key: str, target_list: list
) -> None:
    """Append one HardwareInfo per entry of ``dev_detail[key]`` to ``target_list``.

    Each attribute is read from the first key that exists in the entry,
    trying the CPU name first, then the memory name, then the disk name.

    NOTE(review): only cpu/mem/disk key names are probed, so entries of
    other sections (graphics card, mainboard) presumably end up with empty
    fields — confirm the real key names against the vendor API.
    """
    from app.integrations.lianruan.models import HardwareInfo

    def first_present(entry: dict, *candidates: str):
        """Return the value of the first candidate key found in entry, else ""."""
        for candidate in candidates:
            if candidate in entry:
                return entry[candidate]
        return ""

    for entry in dev_detail.get(key, []):
        component = HardwareInfo(
            name=first_present(entry, "strcpuname", "strmemname", "strdiskname"),
            model=first_present(entry, "strcpumodel", "strmemmodel", "strdiskmodel"),
            vendor=first_present(entry, "strcpuvendor", "strmemvendor", "strdiskvendor"),
            capacity=first_present(entry, "strcpufrequency", "strmemcapacity", "strdiskcapacity"),
            serial=first_present(entry, "strcpuserial", "strmemserial", "strdiskserial"),
        )
        target_list.append(component)
# ==========================================================================
# P0接口 — 组织架构/用户
# ==========================================================================
async def get_user_info_by_account(self, useraccount: str) -> Optional[UserInfo]:
    """Look up one user by account name (getUserInfoByAccount).

    Args:
        useraccount: Account to look up.

    Returns:
        UserInfo built from the first returned row, or None when the
        response carries no row.
    """
    payload = await self._request(
        "/querydeptuser",
        "getUserInfoByAccount",
        {"useraccount": useraccount},
    )
    # The result is read from "rows", falling back to "row"; it may be a
    # list of records or a single record.
    found = payload.get("rows", payload.get("row", []))
    if not found:
        return None
    record = found[0] if isinstance(found, list) else found
    return UserInfo(**record)
async def get_all_org_info(self) -> list[OrgDeptInfo]:
    """Fetch the whole organisation tree (getAllOrgInfo).

    Returns:
        list[OrgDeptInfo]: One entry per row of the response, each carrying
        the users listed under that department.
    """
    payload = await self._request("/querydeptuser", "getAllOrgInfo")
    return [
        OrgDeptInfo(
            deptid=dept.get("deptid", ""),
            deptname=dept.get("deptname", ""),
            parentid=dept.get("parentid", ""),
            users=[UserInfo(**member) for member in dept.get("users", [])],
        )
        for dept in payload.get("rows", [])
    ]
# ==========================================================================
# P1接口 — 准入控制
# ==========================================================================
async def exist_online_user(
    self, username: str, strdevip: str = ""
) -> OnlineStatus:
    """Ask whether a user is currently online (existOnlineUser).

    Args:
        username: User name to check.
        strdevip: Optional IP address narrowing the check; sent only when
            non-empty.

    Returns:
        OnlineStatus: Echoes ``username`` and ``strdevip``; ``is_online`` is
        True only when the response's ``data`` field equals the string "1".
    """
    optional_ip = {"strdevip": strdevip} if strdevip else {}
    query = {"username": username, **optional_ip}

    payload = await self._request(
        "/access/onlineUser", "existOnlineUser", query
    )
    # NOTE(review): the comparison is against the string "1"; a numeric 1
    # would be reported as offline — confirm the payload type.
    return OnlineStatus(
        username=username,
        ip=strdevip,
        is_online=payload.get("data", "0") == "1",
    )
# ==========================================================================
# P1接口 — 终端操作
# ==========================================================================
async def notice_agent_msg(
    self, strdevip: str, message: str
) -> bool:
    """Push a pop-up message to a terminal (noticeAgentMsg).

    Args:
        strdevip: IP address of the target terminal.
        message: Text to show; sent as the ``msg`` parameter.

    Returns:
        bool: True when the response status is SUCCESS.
    """
    body = {"strdevip": strdevip, "msg": message}
    reply = await self._request("/terminal", "noticeAgentMsg", body)
    return reply.get("status") == "SUCCESS"
async def remote_wake_up(
    self, strdevip: str, strmac: str
) -> bool:
    """Ask the server to wake a terminal remotely (remoteWakeUp).

    Args:
        strdevip: IP address of the terminal.
        strmac: MAC address of the terminal.

    Returns:
        bool: True when the response status is SUCCESS.
    """
    target = {"strdevip": strdevip, "strmac": strmac}
    reply = await self._request("/terminal", "remoteWakeUp", target)
    return reply.get("status") == "SUCCESS"
async def query_software_by_dev(
    self, strdevname: str = "", strdevip: str = ""
) -> Optional[TerminalSoftwareInfo]:
    """List the software installed on one terminal (querysoftwarebydev).

    Args:
        strdevname: Computer name; sent only when non-empty.
        strdevip: IP address; sent only when non-empty.

    Returns:
        TerminalSoftwareInfo built from the first returned row, or None
        when the response carries no row.
    """
    # Local import: SoftwareInfo is not part of the module-level import
    # list, so using it bare raised NameError whenever a row listed software.
    from app.integrations.lianruan.models import SoftwareInfo

    params: dict = {}
    if strdevname:
        params["strdevname"] = strdevname
    if strdevip:
        params["strdevip"] = strdevip

    data = await self._request("/software", "querysoftwarebydev", params)
    rows = data.get("rows", [])
    if not rows:
        return None
    row = rows[0] if isinstance(rows, list) else rows

    # "or []" also covers a row whose "softwares" value is null.
    softwares = [
        SoftwareInfo(
            name=s.get("strsoftware", ""),
            version=s.get("strversion", ""),
            vendor=s.get("strvendor", ""),
            install_date=s.get("installdate", ""),
        )
        for s in row.get("softwares") or []
    ]
    return TerminalSoftwareInfo(
        strdevname=row.get("strdevname", ""),
        strdevip=row.get("strdevip", ""),
        strmac=row.get("strmac", ""),
        strusername=row.get("strusername", ""),
        softwares=softwares,
    )
# ==========================================================================
# 测试连接
# ==========================================================================
async def test_connection(self) -> dict:
    """Check connectivity and credentials by obtaining a token.

    Never raises: every failure is folded into the result.

    Returns:
        dict: ``{"success": bool, "message": str}``.
    """
    try:
        token = await self._ensure_token()
    except (LianruanAuthError, LianruanConnectionError) as e:
        return {"success": False, "message": e.message}
    except Exception as e:
        # Boundary handler: report anything unexpected instead of raising.
        return {"success": False, "message": f"未知错误: {str(e)}"}

    if not token:
        return {
            "success": False,
            "message": "Token获取失败,返回为空",
        }
    return {
        "success": True,
        "message": "联软API连接成功,Token获取正常",
    }
@@ -0,0 +1,98 @@
# 联软LV7000配置管理
"""
从system_configs表读取联软API配置,构建LianruanClient实例。
联软配置键(前缀 integration_lianruan_):
- integration_lianruan_base_url: 联软API地址(如 http://192.168.x.x:30098
- integration_lianruan_api_account: API账号
- integration_lianruan_api_password: API密码
- integration_lianruan_validate_key: 验证密钥(可选)
配置方式:管理后台 → 系统集成 → 联软LV7000 → 填入账号密码
"""
import logging
from sqlalchemy.ext.asyncio import AsyncSession
from app.integrations.lianruan.client import LianruanClient
from app.integrations.lianruan.exceptions import LianruanConfigError
from app.models.system_config import SystemConfig
logger = logging.getLogger(__name__)
# 联软配置键前缀(与 admin_service INTEGRATION_DEFINITIONS 中的 key_prefix 一致)
_PREFIX = "integration_lianruan_"
async def _get_lianruan_config_value(db: AsyncSession, key_suffix: str) -> str:
    """Read one Lianruan setting from the SystemConfig table.

    Args:
        db: Database session.
        key_suffix: Part of the key after the ``integration_lianruan_``
            prefix, e.g. ``base_url``.

    Returns:
        str: The stored value, or "" when no row matches.
    """
    from sqlalchemy import select

    # NOTE(review): this query uses SystemConfig.key / .value, while the
    # RAGFlow loader queries SystemConfig.config_key / .config_value; one of
    # the two is presumably stale — verify against the SystemConfig model.
    statement = select(SystemConfig).where(SystemConfig.key == f"{_PREFIX}{key_suffix}")
    outcome = await db.execute(statement)
    record = outcome.scalar_one_or_none()
    if not record:
        return ""
    return record.value
async def get_lianruan_config(db: AsyncSession) -> dict:
    """Load the Lianruan settings and check that the mandatory ones are set.

    Args:
        db: Database session.

    Returns:
        dict: Keys ``base_url``, ``api_account``, ``api_password`` and
        ``validate_key`` (the last one may be empty).

    Raises:
        LianruanConfigError: Base URL, account or password is missing.
    """
    suffixes = ("base_url", "api_account", "api_password", "validate_key")
    # All four values are read before any of them is validated.
    settings = {
        suffix: await _get_lianruan_config_value(db, suffix)
        for suffix in suffixes
    }

    mandatory = (
        ("base_url", "联软API未配置:缺少Base URL"),
        ("api_account", "联软API未配置:缺少API账号"),
        ("api_password", "联软API未配置:缺少API密码"),
    )
    for name, complaint in mandatory:
        if not settings[name]:
            raise LianruanConfigError(complaint)
    return settings
async def get_lianruan_client(db: AsyncSession) -> LianruanClient:
    """Build a LianruanClient from the stored settings.

    Args:
        db: Database session.

    Returns:
        LianruanClient: Client configured with the stored URL and credentials.

    Raises:
        LianruanConfigError: A mandatory setting is missing.
    """
    settings = await get_lianruan_config(db)
    arguments = {
        name: settings[name]
        for name in ("base_url", "api_account", "api_password")
    }
    arguments["validate_key"] = settings.get("validate_key", "")
    return LianruanClient(**arguments)
@@ -0,0 +1,61 @@
# 联软LV7000异常体系
"""
定义联软API集成的异常类层级。
层级:
LianruanError — 基类(所有联软异常)
├── LianruanConfigError — 配置缺失(未填写账号/密码/BaseURL)
├── LianruanAuthError — 认证失败(IP不在白名单/账号密码错误/Token过期)
├── LianruanConnectionError — 网络连接失败(超时/拒绝连接)
└── LianruanApiError — API业务错误(参数错误/数据超限/其他)
"""
class LianruanError(Exception):
    """Base class of every Lianruan integration error.

    Attributes:
        message: Human-readable summary; also the exception's only argument.
        detail: Optional extra context, "" by default.
    """

    def __init__(self, message: str, detail: str = ""):
        super().__init__(message)
        self.detail = detail
        self.message = message
class LianruanConfigError(LianruanError):
    """Raised when the Lianruan Base URL, API account or API password is not configured."""
class LianruanAuthError(LianruanError):
    """Raised when authentication against Lianruan fails.

    Typical causes:
    - the caller's IP is not whitelisted (status=INVALID)
    - wrong account or password
    - the token has expired and must be fetched again
    """
class LianruanConnectionError(LianruanError):
    """Raised when the Lianruan server cannot be reached (timeout, refused connection, DNS failure)."""
class LianruanApiError(LianruanError):
    """Raised when Lianruan reports a business-level error.

    Typical causes:
    - invalid parameters (status=ERROR)
    - too much data requested (status=Exceed)
    - any other business failure

    Attributes:
        status: The ``status`` value Lianruan returned (ERROR, Exceed, ...).
    """

    def __init__(self, message: str, status: str = "", detail: str = ""):
        super().__init__(message, detail)
        self.status = status
+193
View File
@@ -0,0 +1,193 @@
# 联软LV7000数据模型
"""
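Pydantic models for the data returned by the Lianruan API.

Core models:
- TerminalBasicInfo: basic terminal information (returned by queryDevByParams)
- TerminalAllInfo: detailed terminal information (returned by getDevAllInfo)
- UserInfo: user information (returned by getUserInfoByAccount)
- OrgDeptInfo: organisation/department information (returned by getAllOrgInfo)
- OnlineStatus: terminal online status (returned by existOnlineUser)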
"""
from typing import Optional
from pydantic import BaseModel, Field
# ==========================================================================
# 终端基本信息(queryDevByParams返回)
# ==========================================================================
class TerminalBasicInfo(BaseModel):
    """Basic terminal record as returned by queryDevByParams.

    ``strusername`` and ``struserdes`` tie a terminal to the employee
    account that uses it. Every field defaults to an empty string.
    """

    # Terminal identity
    strdevname: str = Field("", description="计算机名")
    strdevip: str = Field("", description="IP地址")
    strmac: str = Field("", description="MAC地址")

    # Employee mapping
    strusername: str = Field("", description="使用该终端的用户账号(映射金钥匙)")
    struserdes: str = Field("", description="用户姓名/描述")

    # Organisation
    strdeptname: str = Field("", description="所属部门名")

    # Status (per the description: 1 = online, 0 = offline)
    istatus: str = Field("", description="终端状态(1=在线/0=离线)")

    # Network attachment
    strswitchname: str = Field("", description="接入交换机名")
    strifname: str = Field("", description="交换机接口名")

    # Contact details
    strmail: str = Field("", description="用户邮箱")
    strphone: str = Field("", description="用户电话")

    # Miscellaneous
    strdomain: str = Field("", description="Windows域")
    strdevtype: str = Field("", description="设备类型")
# ==========================================================================
# 终端详细信息(getDevAllInfo返回)
# ==========================================================================
class HardwareInfo(BaseModel):
    """One hardware component: name, model, vendor, capacity and serial number."""

    name: str = Field("", description="名称")
    model: str = Field("", description="型号")
    vendor: str = Field("", description="厂商")
    capacity: str = Field("", description="容量")
    serial: str = Field("", description="序列号")
class LogicalDiskInfo(BaseModel):
    """One logical disk, including its usage percentage (useful for spotting full disks)."""

    name: str = Field("", description="卷标")
    file_system: str = Field("", description="文件系统")
    total_size: str = Field("", description="总量")
    free_space: str = Field("", description="可用空间")
    usage_percent: str = Field("", description="使用率")
class NetworkCardInfo(BaseModel):
    """One network adapter: name, wireless flag, vendor and MAC address."""

    name: str = Field("", description="名称")
    is_wireless: str = Field("", description="是否无线")
    vendor: str = Field("", description="厂商")
    mac: str = Field("", description="MAC地址")
class DisplayInfo(BaseModel):
    """One attached display (helps when troubleshooting multi-monitor setups)."""

    vendor: str = Field("", description="厂商")
    model: str = Field("", description="型号")
    serial: str = Field("", description="序列号")
    size: str = Field("", description="尺寸")
class TerminalAllInfo(BaseModel):
    """Detailed terminal record as returned by getDevAllInfo.

    Combines device basics, owner, timestamps, system data, agent details
    and per-component hardware lists. Scalar fields default to "" and list
    fields to an empty list.
    """

    # Device basics
    strdevname: str = Field("", description="计算机名")
    strip1: str = Field("", description="IP地址")
    strmac: str = Field("", description="MAC地址")
    strnatip: str = Field("", description="NAT IP")
    macverdor: str = Field("", description="MAC厂商")
    strdevtype: str = Field("", description="设备类型")

    # Organisation and user
    strdeptname: str = Field("", description="所属部门")
    strusername: str = Field("", description="用户账号⭐")
    struserdes: str = Field("", description="用户姓名⭐")

    # Timestamps
    dtdevuptime: str = Field("", description="最近上线时间")
    dtdevdowntime: str = Field("", description="最近下线时间")
    dtdevfirstfoundtime: str = Field("", description="首次发现时间")

    # System
    stros: str = Field("", description="操作系统")
    strdomain: str = Field("", description="Windows域")
    strserialnumber: str = Field("", description="序列号")
    strmainboardtype: str = Field("", description="主板型号")

    # Agent / client details
    strverofuaagent: str = Field("", description="安全助手版本")
    istatus: str = Field("", description="在线状态")
    devassetno: str = Field("", description="设备资产号")
    devgroup: str = Field("", description="设备所属设备组")

    # Hardware details (one list per component type)
    mainboard: list[HardwareInfo] = Field(description="主板信息", default_factory=list)
    cpu: list[HardwareInfo] = Field(description="CPU信息", default_factory=list)
    memory: list[HardwareInfo] = Field(description="内存信息", default_factory=list)
    hard_disk: list[HardwareInfo] = Field(description="硬盘信息", default_factory=list)
    logical_disk: list[LogicalDiskInfo] = Field(description="逻辑磁盘", default_factory=list)
    graphics_card: list[HardwareInfo] = Field(description="显卡信息", default_factory=list)
    network_card: list[NetworkCardInfo] = Field(description="网卡信息", default_factory=list)
    display: list[DisplayInfo] = Field(description="显示器信息", default_factory=list)
# ==========================================================================
# 用户信息(getUserInfoByAccount返回)
# ==========================================================================
class UserInfo(BaseModel):
    """A user record: department id, user id, account and display name."""

    deptid: str = Field("", description="部门ID")
    userid: str = Field("", description="用户ID")
    useraccount: str = Field("", description="用户账号")
    username: str = Field("", description="用户姓名")
# ==========================================================================
# 组织架构信息(getAllOrgInfo返回)
# ==========================================================================
class OrgDeptInfo(BaseModel):
    """A department with its parent id and the users listed under it."""

    deptid: str = Field("", description="部门ID")
    deptname: str = Field("", description="部门名称")
    parentid: str = Field("", description="父部门ID")
    users: list[UserInfo] = Field(description="部门下用户列表", default_factory=list)
# ==========================================================================
# 终端在线状态(existOnlineUser返回)
# ==========================================================================
class OnlineStatus(BaseModel):
    """Online state of a user, optionally tied to an IP address."""

    username: str = Field("", description="用户名")
    ip: str = Field("", description="IP地址")
    is_online: bool = Field(False, description="是否在线")
# ==========================================================================
# 软件信息(querysoftwarebydev返回)
# ==========================================================================
class SoftwareInfo(BaseModel):
    """One installed software package: name, version, vendor and install date."""

    name: str = Field("", description="软件名称")
    version: str = Field("", description="版本")
    vendor: str = Field("", description="厂商")
    install_date: str = Field("", description="安装日期")
class TerminalSoftwareInfo(BaseModel):
    """The software installed on one terminal, plus the terminal's identity."""

    strdevname: str = Field("", description="计算机名")
    strdevip: str = Field("", description="IP地址")
    strmac: str = Field("", description="MAC地址")
    strusername: str = Field("", description="用户账号")
    softwares: list[SoftwareInfo] = Field(description="软件列表", default_factory=list)
@@ -0,0 +1,35 @@
# =============================================================================
# RAGFlow 集成模块
# =============================================================================
from .client import RagflowClient
from .config import get_ragflow_client
from .exceptions import (
RagflowApiError,
RagflowAuthError,
RagflowConfigError,
RagflowConnectionError,
RagflowError,
)
from .models import (
DatasetInfo,
DocAggregate,
DocumentInfo,
RetrievalChunk,
RetrievalResult,
)
__all__ = [
"RagflowClient",
"get_ragflow_client",
"RagflowError",
"RagflowConfigError",
"RagflowAuthError",
"RagflowApiError",
"RagflowConnectionError",
"RetrievalChunk",
"DocAggregate",
"RetrievalResult",
"DatasetInfo",
"DocumentInfo",
]
+449
View File
@@ -0,0 +1,449 @@
# =============================================================================
# RAGFlow API 客户端
# =============================================================================
# 说明:封装 RAGFlow 知识检索引擎的 API 调用
# 核心功能:
# 1. 知识检索 — POST /api/v1/retrieval(核心接口)
# 2. 数据集管理 — 列出/创建/删除知识库
# 3. 文档管理 — 上传/列出/删除文档
# 4. 测试连接 — 验证 API Key 是否有效
# 认证方式:Authorization: Bearer <API_KEY>
# 参考文档:https://ragflow.io/docs/http_api_reference
# =============================================================================
import logging
from typing import Any, Dict, List, Optional
import httpx
from .exceptions import (
RagflowApiError,
RagflowAuthError,
RagflowConfigError,
RagflowConnectionError,
RagflowError,
)
from .models import (
DatasetInfo,
DocAggregate,
DocumentInfo,
RetrievalChunk,
RetrievalResult,
)
logger = logging.getLogger(__name__)
# 默认请求超时(秒)
DEFAULT_TIMEOUT = 30.0
# 默认分页大小
DEFAULT_PAGE_SIZE = 20
class RagflowClient:
"""RAGFlow API 客户端。
封装 RAGFlow 知识检索引擎的 API 调用,支持:
- 知识检索(核心功能)
- 数据集(知识库)管理
- 文档管理
- 连接测试
使用方式:
client = RagflowClient(
api_key="sk-xxx",
base_url="http://10.80.0.85:9380"
)
result = await client.retrieval("VPN怎么连?", dataset_ids=["xxx"])
"""
def __init__(
    self,
    api_key: str,
    base_url: str = "http://10.80.0.85:9380",
    timeout: float = DEFAULT_TIMEOUT,
) -> None:
    """Store the API key, base address and default timeout.

    Args:
        api_key: RAGFlow API key, sent as a Bearer token.
        base_url: Base address of the RAGFlow API; trailing slashes are
            stripped.
        timeout: Default request timeout in seconds.

    Raises:
        RagflowConfigError: ``api_key`` is empty.
    """
    if not api_key:
        raise RagflowConfigError("RAGFlow API Key 不能为空")

    self.timeout = timeout
    self.base_url = base_url.rstrip("/")
    self.api_key = api_key
def _headers(self) -> Dict[str, str]:
    """Build the headers for a JSON request.

    Returns:
        Dict: ``Authorization`` (Bearer token) and ``Content-Type``.
    """
    headers: Dict[str, str] = {}
    headers["Authorization"] = f"Bearer {self.api_key}"
    headers["Content-Type"] = "application/json"
    return headers
async def _request(
    self,
    method: str,
    path: str,
    json_data: Optional[Dict] = None,
    params: Optional[Dict] = None,
    timeout: Optional[float] = None,
) -> Dict[str, Any]:
    """Send one JSON request to RAGFlow and return the decoded envelope.

    Args:
        method: HTTP method (GET/POST/PUT/DELETE).
        path: API path appended to ``base_url``, e.g. ``/api/v1/retrieval``.
        json_data: JSON request body.
        params: Query-string parameters.
        timeout: Per-call timeout; a falsy value falls back to ``self.timeout``.

    Returns:
        Dict: The whole response envelope; only returned when ``code`` is 0.

    Raises:
        RagflowAuthError: HTTP 401.
        RagflowApiError: Any other HTTP status >= 400, or a non-zero ``code``.
        RagflowConnectionError: Timeout or connection failure.
        RagflowError: Any other failure while sending or decoding.
    """
    effective_timeout = timeout or self.timeout
    target = f"{self.base_url}{path}"

    try:
        # A new httpx client is opened for every call.
        async with httpx.AsyncClient() as session:
            response = await session.request(
                method=method,
                url=target,
                headers=self._headers(),
                json=json_data,
                params=params,
                timeout=effective_timeout,
            )

            status_code = response.status_code
            if status_code == 401:
                raise RagflowAuthError("RAGFlow API Key 无效或已过期")
            if status_code >= 400:
                # Prefer the server's "message"; fall back to the raw body.
                try:
                    reason = response.json().get("message", response.text)
                except Exception:
                    reason = response.text
                raise RagflowApiError(
                    code=status_code,
                    message=f"RAGFlow API 错误 ({status_code}): {reason}",
                )

            # RAGFlow wraps every answer as {code: 0, data: ..., message: ...}.
            envelope = response.json()
            if envelope.get("code") != 0:
                raise RagflowApiError(
                    code=envelope.get("code", -1),
                    message=envelope.get("message", "未知错误"),
                )
            return envelope

    except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
        # Already one of ours: let it through unchanged.
        raise
    except httpx.TimeoutException:
        raise RagflowConnectionError(f"RAGFlow 请求超时 ({effective_timeout}s): {path}")
    except httpx.ConnectError:
        raise RagflowConnectionError(f"RAGFlow 连接失败: {self.base_url}")
    except Exception as e:
        raise RagflowError(f"RAGFlow 请求异常: {str(e)}")
# ==========================================================================
# 测试连接
# ==========================================================================
async def test_connection(self) -> Dict[str, Any]:
    """Check the connection by listing a single dataset.

    RagflowError and its handled variants are turned into a failure
    result instead of being raised.

    Returns:
        Dict: ``{success: bool, message: str}``.
    """
    try:
        listing = await self.list_datasets(page=1, page_size=1)
    except RagflowAuthError:
        return {"success": False, "message": "API Key 无效或已过期"}
    except RagflowConnectionError as e:
        return {"success": False, "message": f"连接失败: {e.message}"}
    except RagflowError as e:
        return {"success": False, "message": e.message}
    else:
        return {
            "success": True,
            "message": f"连接成功,共 {listing.get('total', 0)} 个知识库",
        }
# ==========================================================================
# 知识检索(核心接口)
# ==========================================================================
async def retrieval(
    self,
    question: str,
    dataset_ids: Optional[List[str]] = None,
    document_ids: Optional[List[str]] = None,
    similarity_threshold: float = 0.2,
    vector_similarity_weight: float = 0.3,
    top_k: int = 1024,
    keyword: bool = False,
    highlight: bool = False,
) -> RetrievalResult:
    """Search the knowledge base for chunks relevant to a question.

    Args:
        question: The user's query.
        dataset_ids: Dataset ids to search; sent only when non-empty.
        document_ids: Document ids to search; sent only when non-empty.
        similarity_threshold: Minimum similarity (0-1).
        vector_similarity_weight: Weight of vector similarity (0-1).
        top_k: Number of chunks taking part in the computation.
        keyword: Whether to enable keyword matching.
        highlight: Whether to highlight matched terms.

    Returns:
        RetrievalResult: Chunks, per-document aggregates and the total.

    Raises:
        RagflowError: The request failed (propagated from ``_request``).
    """
    payload: Dict[str, Any] = dict(
        question=question,
        similarity_threshold=similarity_threshold,
        vector_similarity_weight=vector_similarity_weight,
        top_k=top_k,
        keyword=keyword,
        highlight=highlight,
    )
    # The id filters are optional and left out when empty.
    for field, ids in (("dataset_ids", dataset_ids), ("document_ids", document_ids)):
        if ids:
            payload[field] = ids

    response = await self._request("POST", "/api/v1/retrieval", json_data=payload)
    found = response.get("data", {})

    return RetrievalResult(
        chunks=[RetrievalChunk.model_validate(raw) for raw in found.get("chunks", [])],
        doc_aggs=[DocAggregate.model_validate(raw) for raw in found.get("doc_aggs", [])],
        total=found.get("total", 0),
    )
# ==========================================================================
# 数据集(知识库)管理
# ==========================================================================
async def list_datasets(
    self,
    page: int = 1,
    page_size: int = DEFAULT_PAGE_SIZE,
) -> Dict[str, Any]:
    """List datasets (knowledge bases).

    Accepts both response layouts: ``data`` as a plain list of datasets
    (the shape shown in the RAGFlow HTTP API reference) and ``data`` as a
    dict with ``datasets`` / ``total`` keys (the shape this method
    originally assumed).

    Args:
        page: Page number.
        page_size: Items per page.

    Returns:
        Dict: ``{items: List[DatasetInfo], total: int}``.
    """
    result = await self._request(
        "GET",
        "/api/v1/datasets",
        params={"page": page, "page_size": page_size},
    )
    data = result.get("data") or {}

    if isinstance(data, list):
        # List layout: the count, when present, sits next to "data".
        # NOTE(review): "total_datasets" is the key name recalled from the
        # reference — confirm against the deployed RAGFlow version.
        raw_items = data
        total = result.get("total_datasets", result.get("total", len(raw_items)))
    else:
        raw_items = data.get("datasets", [])
        total = data.get("total", 0)

    items = [DatasetInfo.model_validate(ds) for ds in raw_items]
    return {"items": items, "total": total}
async def create_dataset(
    self,
    name: str,
    embedding_model: str = "BAAI/bge-m3@BAAI",
    chunk_method: str = "naive",
    permission: str = "me",
) -> DatasetInfo:
    """Create a dataset (knowledge base).

    Args:
        name: Dataset name.
        embedding_model: Embedding model identifier.
        chunk_method: Chunking method (naive/qa/book/laws, ...).
        permission: Access scope (me/team).

    Returns:
        DatasetInfo: Parsed from the ``data`` field of the response.
    """
    response = await self._request(
        "POST",
        "/api/v1/datasets",
        json_data=dict(
            name=name,
            embedding_model=embedding_model,
            chunk_method=chunk_method,
            permission=permission,
        ),
    )
    created = response.get("data", {})
    return DatasetInfo.model_validate(created)
async def delete_dataset(self, dataset_ids: List[str]) -> bool:
    """Delete datasets by id.

    Args:
        dataset_ids: Ids of the datasets to delete.

    Returns:
        bool: Always True; failures surface as exceptions from ``_request``.
    """
    payload = {"ids": dataset_ids}
    await self._request("DELETE", "/api/v1/datasets", json_data=payload)
    return True
# ==========================================================================
# 文档管理
# ==========================================================================
async def list_documents(
    self,
    dataset_id: str,
    page: int = 1,
    page_size: int = DEFAULT_PAGE_SIZE,
) -> Dict[str, Any]:
    """List the documents of one dataset.

    The document list is read from ``data.docs`` (the key shown in the
    RAGFlow HTTP API reference), falling back to ``data.documents`` (the
    key this method originally read).

    Args:
        dataset_id: Dataset id.
        page: Page number.
        page_size: Items per page.

    Returns:
        Dict: ``{items: List[DocumentInfo], total: int}``.
    """
    result = await self._request(
        "GET",
        f"/api/v1/datasets/{dataset_id}/documents",
        params={"page": page, "page_size": page_size},
    )
    data = result.get("data") or {}

    raw_docs = data.get("docs")
    if raw_docs is None:
        raw_docs = data.get("documents", [])

    # NOTE(review): some releases appear to report the count under
    # "total_datasets" instead of "total" — confirm against the server.
    total = data.get("total", data.get("total_datasets", 0))

    items = [DocumentInfo.model_validate(doc) for doc in raw_docs]
    return {"items": items, "total": total}
async def upload_document(
    self,
    dataset_id: str,
    file_path: str,
    file_name: Optional[str] = None,
) -> DocumentInfo:
    """Upload one local file to a dataset.

    Args:
        dataset_id: Target dataset id.
        file_path: Path of the local file.
        file_name: Name to upload under; defaults to the base name of
            ``file_path``.

    Returns:
        DocumentInfo: The first document reported by the server, or a
        DocumentInfo carrying only the name when none is reported.

    Raises:
        RagflowAuthError: HTTP 401.
        RagflowApiError: Any other HTTP status >= 400, or a non-zero ``code``.
        RagflowConnectionError: Timeout or connection failure.
        RagflowError: Missing file or any other failure.
    """
    import os

    if not os.path.exists(file_path):
        raise RagflowError(f"文件不存在: {file_path}")

    fname = file_name or os.path.basename(file_path)
    path = f"/api/v1/datasets/{dataset_id}/documents"
    upload_timeout = 60.0

    try:
        async with httpx.AsyncClient() as client:
            with open(file_path, "rb") as f:
                # Only the Authorization header is sent: httpx sets the
                # multipart Content-Type itself.
                response = await client.post(
                    url=f"{self.base_url}{path}",
                    headers={"Authorization": f"Bearer {self.api_key}"},
                    files={"file": (fname, f)},
                    timeout=upload_timeout,
                )

        # Same HTTP error handling as _request.
        if response.status_code == 401:
            raise RagflowAuthError("RAGFlow API Key 无效或已过期")
        if response.status_code >= 400:
            try:
                err_msg = response.json().get("message", response.text)
            except Exception:
                err_msg = response.text
            raise RagflowApiError(
                code=response.status_code,
                message=f"RAGFlow API 错误 ({response.status_code}): {err_msg}",
            )

        result = response.json()
        if result.get("code") != 0:
            raise RagflowApiError(
                code=result.get("code", -1),
                message=result.get("message", "上传失败"),
            )

        # "data" is either the list of uploaded documents or a dict
        # holding them under "documents".
        data = result.get("data") or {}
        docs = data if isinstance(data, list) else data.get("documents", [])
        if docs:
            return DocumentInfo.model_validate(docs[0])
        return DocumentInfo(name=fname)

    except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
        raise
    except httpx.TimeoutException as e:
        raise RagflowConnectionError(
            f"RAGFlow 请求超时 ({upload_timeout}s): {path}"
        ) from e
    except httpx.ConnectError as e:
        raise RagflowConnectionError(f"RAGFlow 连接失败: {self.base_url}") from e
    except Exception as e:
        raise RagflowError(f"文档上传失败: {str(e)}") from e
async def delete_documents(
    self,
    dataset_id: str,
    document_ids: List[str],
) -> bool:
    """Delete documents from a dataset.

    Args:
        dataset_id: ID of the dataset that owns the documents.
        document_ids: IDs of the documents to delete. An empty list means
            "nothing to delete": no request is sent.

    Returns:
        bool: Always ``True`` when the call returns normally. Errors are
        expected to propagate as exceptions from ``self._request``
        (defined outside this block).
    """
    # Never forward an empty selection. Per the RAGFlow HTTP API reference a
    # missing/null ``ids`` deletes ALL documents in the dataset, and some
    # server releases handle an empty list the same way.
    if not document_ids:
        return True

    await self._request(
        "DELETE",
        f"/api/v1/datasets/{dataset_id}/documents",
        json_data={"ids": document_ids},
    )
    return True
@@ -0,0 +1,61 @@
# =============================================================================
# RAGFlow 配置加载器
# =============================================================================
# 说明:从数据库 system_configs 表加载 RAGFlow 配置,创建客户端实例
# 配置项:integration_ragflow_api_url + integration_ragflow_api_key
import logging
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.system_config import SystemConfig
from .client import RagflowClient
from .exceptions import RagflowConfigError
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Default RAGFlow API address (production environment). get_ragflow_client()
# falls back to it when integration_ragflow_api_url is not configured.
# NOTE(review): this is a hard-coded internal address — confirm it should
# remain the silent fallback for every deployment.
DEFAULT_RAGFLOW_BASE_URL = "http://10.80.0.85:9380"
async def _get_config(db: AsyncSession, key: str) -> str:
    """Fetch one configuration value from the database.

    Args:
        db: Async database session used to run the query.
        key: Value matched against ``SystemConfig.config_key``.

    Returns:
        str: The stored ``config_value``, or an empty string when no row
        matches or the stored value is falsy.
    """
    query = select(SystemConfig.config_value).where(SystemConfig.config_key == key)
    value = (await db.execute(query)).scalar()
    if not value:
        return ""
    return value
async def get_ragflow_client(db: AsyncSession) -> RagflowClient:
    """Build a RAGFlow client from the configuration stored in the database.

    Two keys are read through ``_get_config``:
        - integration_ragflow_api_url: RAGFlow API address
        - integration_ragflow_api_key: RAGFlow API key

    Args:
        db: Async database session.

    Returns:
        RagflowClient: Client created with the configured key and address.

    Raises:
        RagflowConfigError: No API key is configured.
    """
    configured_url = await _get_config(db, "integration_ragflow_api_url")
    api_key = await _get_config(db, "integration_ragflow_api_key")

    # The key has no usable default, so its absence is a hard error.
    if not api_key:
        raise RagflowConfigError(
            "RAGFlow API Key 未配置,请在管理后台 → 集成管理 → RAGFlow 中设置"
        )

    # The address falls back to the built-in default when not configured.
    base_url = configured_url or DEFAULT_RAGFLOW_BASE_URL
    return RagflowClient(api_key=api_key, base_url=base_url)
@@ -0,0 +1,35 @@
# =============================================================================
# RAGFlow API 异常定义
# =============================================================================
class RagflowError(Exception):
    """Base class for all RAGFlow integration errors.

    The message is passed to ``Exception`` and also exposed as the
    ``message`` attribute.
    """

    def __init__(self, message: str = "RAGFlow 错误"):
        super().__init__(message)
        self.message = message
class RagflowConfigError(RagflowError):
    """Configuration error (missing API key or base URL)."""

    def __init__(self, message: str = "RAGFlow 配置缺失"):
        # Only the default message differs from the base class.
        super().__init__(message=message)
class RagflowAuthError(RagflowError):
    """Authentication failure (invalid API key)."""

    def __init__(self, message: str = "RAGFlow 认证失败"):
        # Only the default message differs from the base class.
        super().__init__(message=message)
class RagflowApiError(RagflowError):
    """A RAGFlow API call failed.

    Attributes:
        code: Error code supplied by the caller (defaults to 0).
        message: Error message, stored by the base class.
    """

    def __init__(self, code: int = 0, message: str = "RAGFlow API 错误"):
        super().__init__(message=message)
        self.code = code
class RagflowConnectionError(RagflowError):
    """Network connection failure."""

    def __init__(self, message: str = "RAGFlow 连接失败"):
        # Only the default message differs from the base class.
        super().__init__(message=message)
+110
View File
@@ -0,0 +1,110 @@
# =============================================================================
# RAGFlow API 数据模型
# =============================================================================
# 说明:定义 RAGFlow API 请求/响应的 Pydantic 数据模型
# 参考:https://ragflow.io/docs/http_api_reference
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class RetrievalChunk(BaseModel):
    """A single text chunk returned by a retrieval query.

    Attributes:
        id: Unique chunk ID.
        content: Text content of the chunk.
        document_id: ID of the document the chunk belongs to.
        document_keyword: Name of the document the chunk belongs to.
        similarity: Combined similarity score.
        term_similarity: Keyword (term) similarity.
        vector_similarity: Vector similarity.
        highlight: Content with highlight markup (optional).
    """

    model_config = {"from_attributes": True}

    id: str = Field("", description="块唯一ID")
    content: str = Field("", description="块内容文本")
    document_id: str = Field("", description="所属文档ID")
    document_keyword: str = Field("", description="所属文档名称")
    similarity: float = Field(0.0, description="综合相似度分数")
    term_similarity: float = Field(0.0, description="关键词相似度")
    vector_similarity: float = Field(0.0, description="向量相似度")
    highlight: Optional[str] = Field(None, description="高亮标记的内容")
class DocAggregate(BaseModel):
    """Per-document aggregate of retrieval hits.

    Attributes:
        doc_id: Document ID.
        doc_name: Document name.
        count: Number of chunks that matched in this document.
    """

    model_config = {"from_attributes": True}

    doc_id: str = Field("", description="文档ID")
    doc_name: str = Field("", description="文档名称")
    count: int = Field(0, description="命中块数量")
class RetrievalResult(BaseModel):
    """Result of a retrieval query.

    Attributes:
        chunks: Matched text chunks.
        doc_aggs: Hit statistics aggregated per document.
        total: Total number of hits.
    """

    model_config = {"from_attributes": True}

    chunks: List[RetrievalChunk] = Field(
        default_factory=list,
        description="命中文本块列表",
    )
    doc_aggs: List[DocAggregate] = Field(
        default_factory=list,
        description="文档聚合统计",
    )
    total: int = Field(0, description="命中总数")
class DatasetInfo(BaseModel):
    """Dataset (knowledge base) information.

    Attributes:
        id: Dataset ID.
        name: Dataset name.
        chunk_method: Chunking method.
        permission: Access permission.
        document_count: Number of documents in the dataset.
        embedding_model: Embedding model name.
        create_time: Creation time, kept as a string.
        update_time: Last update time, kept as a string.
    """

    id: str = Field(default="", description="数据集ID")
    name: str = Field(default="", description="数据集名称")
    chunk_method: str = Field(default="naive", description="分块方法")
    permission: str = Field(default="me", description="权限")
    document_count: int = Field(default=0, description="文档数量")
    embedding_model: str = Field(default="", description="向量模型")
    create_time: Optional[str] = Field(default=None, description="创建时间")
    update_time: Optional[str] = Field(default=None, description="更新时间")

    # The RAGFlow API reference shows create_time/update_time as integer
    # epoch timestamps. Pydantic v2 does not coerce numbers into ``str``
    # fields by default, so validating a raw API payload would raise a
    # ValidationError. ``coerce_numbers_to_str`` lets those values validate
    # while the declared field types (and string inputs) stay unchanged.
    model_config = {"from_attributes": True, "coerce_numbers_to_str": True}
class DocumentInfo(BaseModel):
    """Document information.

    Attributes:
        id: Document ID.
        name: Document name.
        chunk_method: Chunking method.
        chunk_count: Number of chunks.
        create_time: Creation time, kept as a string.
        update_time: Last update time, kept as a string.
    """

    id: str = Field(default="", description="文档ID")
    name: str = Field(default="", description="文档名称")
    chunk_method: str = Field(default="naive", description="分块方法")
    chunk_count: int = Field(default=0, description="块数量")
    create_time: Optional[str] = Field(default=None, description="创建时间")
    update_time: Optional[str] = Field(default=None, description="更新时间")

    # The RAGFlow API reference shows create_time/update_time as integer
    # epoch timestamps. Pydantic v2 does not coerce numbers into ``str``
    # fields by default, so validating a raw API payload would raise a
    # ValidationError. ``coerce_numbers_to_str`` lets those values validate
    # while the declared field types (and string inputs) stay unchanged.
    model_config = {"from_attributes": True, "coerce_numbers_to_str": True}