chore: initial baseline with P0-safety .gitignore
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 外部系统集成模块包
|
||||
# =============================================================================
|
||||
# 说明:各外部系统的 API 客户端、数据模型、异常定义等
|
||||
# 当前已实现:火绒终端安全
|
||||
#
|
||||
@@ -0,0 +1,3 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 火绒终端安全集成模块包
|
||||
# =============================================================================
|
||||
@@ -0,0 +1,658 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 火绒终端安全 API 客户端
|
||||
# =============================================================================
|
||||
# 说明:封装火绒API的签名、请求、响应处理
|
||||
# 核心功能:
|
||||
# 1. HRESS 签名实现(Authorization Header 方式)
|
||||
# 2. 统一请求封装(超时、重试、异常处理)
|
||||
# 3. P0 接口:终端列表 _list / 终端详情 _info2 / 高危漏洞 _leak / 病毒事件 _virus_events
|
||||
# 4. P1 接口:终端隔离/解除 _create(netctrl) / 快速扫描 / 在线终端查询
|
||||
# 签名算法(来自火绒官方API文档 v1):
|
||||
# Authorization = "HRESS" + AccessKeyId + ":" + Expires + ":" + Signature
|
||||
# Signature = urlencode(base64(hmac-sha1(AccessKeySecret,
|
||||
# AccessKeyId + "\n" + Expires + "\n" + HTTP-METHOD + "\n"
|
||||
# + Content-MD5 + "\n" + CanonicalizedResource)))
|
||||
# CanonicalizedResource = API路径(无前导/),含排序后的查询参数
|
||||
# Content-MD5 = base64(md5_digest(body_bytes)) — RFC2616
|
||||
# 使用方式:
|
||||
# client = HuorongClient(access_key_id="...", access_key_secret="...", base_url="...")
|
||||
# terminals = await client.list_terminals()
|
||||
# =============================================================================
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import (
|
||||
HuorongApiError,
|
||||
HuorongAuthError,
|
||||
HuorongConnectionError,
|
||||
HuorongError,
|
||||
)
|
||||
from .models import (
|
||||
HuorongApiResponse,
|
||||
TerminalBasicInfo,
|
||||
TerminalDetailV2,
|
||||
TerminalLeakInfo,
|
||||
VirusEventStats,
|
||||
VirusHandleResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认请求超时(秒)— 火绒内网响应通常在1秒内,3秒足够兜底
|
||||
# 注意:_virus_events 查询全部终端时可能较慢,需要更长超时
|
||||
DEFAULT_TIMEOUT = 10.0
|
||||
|
||||
# 默认分页大小
|
||||
DEFAULT_PAGE_SIZE = 20
|
||||
|
||||
# 签名有效期(秒)— 请求签名中的 Expires 字段
|
||||
SIGN_EXPIRES_SECONDS = 300
|
||||
|
||||
|
||||
class HuorongClient:
|
||||
"""火绒终端安全 API 客户端。
|
||||
|
||||
封装了火绒API的签名认证、请求发送和响应解析。
|
||||
所有方法均为异步(async),使用 httpx.AsyncClient 发送请求。
|
||||
|
||||
签名方式:HRESS Authorization Header
|
||||
参考:火绒终端安全管理系统API说明文档 v1
|
||||
|
||||
Attributes:
|
||||
access_key_id: 火绒 AccessKey ID(控制中心显示为 Secret ID)
|
||||
access_key_secret: 火绒 AccessKey Secret(控制中心显示为 Secret Key)
|
||||
base_url: 火绒API内网地址(如 http://huorong.oa.servyou-it.com:8080)
|
||||
timeout: 请求超时秒数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_key_id: str,
|
||||
access_key_secret: str,
|
||||
base_url: str,
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
):
|
||||
"""初始化火绒API客户端。
|
||||
|
||||
Args:
|
||||
access_key_id: 火绒 AccessKey ID(控制中心显示为 Secret ID)
|
||||
access_key_secret: 火绒 AccessKey Secret(控制中心显示为 Secret Key)
|
||||
base_url: 火绒API内网地址(不含尾部斜杠)
|
||||
timeout: 请求超时秒数,默认3秒
|
||||
"""
|
||||
self.access_key_id = access_key_id
|
||||
self.access_key_secret = access_key_secret
|
||||
# 确保 base_url 不以 / 结尾,拼接路径时统一加 /
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
|
||||
# ======================================================================
|
||||
# 签名实现(HRESS Authorization Header 方式)
|
||||
# ======================================================================
|
||||
|
||||
def _compute_content_md5(self, body_bytes: bytes) -> str:
|
||||
"""计算请求体的 Content-MD5(RFC2616)。
|
||||
|
||||
算法步骤:
|
||||
1. 计算请求体的 MD5 二进制摘要(128位)
|
||||
2. 对二进制摘要进行 base64 编码
|
||||
|
||||
注意:不是对32位十六进制字符串编码,而是对原始二进制摘要编码。
|
||||
|
||||
Args:
|
||||
body_bytes: 请求体的字节内容
|
||||
|
||||
Returns:
|
||||
str: base64 编码的 MD5 摘要
|
||||
"""
|
||||
md5_digest = hashlib.md5(body_bytes).digest()
|
||||
return base64.b64encode(md5_digest).decode("utf-8")
|
||||
|
||||
def _build_canonicalized_resource(self, path: str) -> str:
|
||||
"""构建 CanonicalizedResource。
|
||||
|
||||
根据火绒API文档:
|
||||
1. 将 CanonicalizedResource 置为空字符串
|
||||
2. 设置要访问的资源路径(去掉前导 /),如 "api/clnts/_list"
|
||||
3. 如果请求包含子资源(查询参数),按字典序排列,
|
||||
以 & 为分隔符生成子资源字符串,末尾添加 ? 和子资源字符串
|
||||
|
||||
示例:
|
||||
- /api/clnts/_list → "api/clnts/_list"
|
||||
- /api/group/_info?group_id=1 → "api/group/_info?group_id=1"
|
||||
|
||||
Args:
|
||||
path: 请求路径(如 /api/clnts/_list)
|
||||
|
||||
Returns:
|
||||
str: CanonicalizedResource 字符串
|
||||
"""
|
||||
# 去掉前导 /
|
||||
return path.lstrip("/")
|
||||
|
||||
def _sign_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body_bytes: bytes = b"",
|
||||
) -> Dict[str, str]:
|
||||
"""生成火绒API请求签名(HRESS Authorization Header 方式)。
|
||||
|
||||
签名算法(来自火绒官方API文档):
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Authorization = "HRESS" + AccessKeyId + ":" + Expires + ":" + Signature │
|
||||
│ │
|
||||
│ Signature = urlencode(base64(hmac-sha1(AccessKeySecret, │
|
||||
│ AccessKeyId + "\\n" │
|
||||
│ + Expires + "\\n" │
|
||||
│ + HTTP-METHOD + "\\n" │
|
||||
│ + Content-MD5 + "\\n" │
|
||||
│ + CanonicalizedResource))) │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
|
||||
其中:
|
||||
- AccessKeyId: 标识用户身份
|
||||
- Expires: Unix 时间戳,签名过期时间(当前时间 + 300秒)
|
||||
- HTTP-METHOD: POST(火绒API统一使用 POST)
|
||||
- Content-MD5: 请求体的 RFC2616 MD5(base64编码)
|
||||
- CanonicalizedResource: API资源路径(去掉前导/)
|
||||
|
||||
Args:
|
||||
method: HTTP方法(POST)
|
||||
path: 请求路径(如 /api/clnts/_list)
|
||||
body_bytes: 请求体字节内容
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 包含 Authorization 和 Content-Type 的 Header 字典
|
||||
"""
|
||||
# 1. 计算过期时间(Unix时间戳)
|
||||
expires = str(int(time.time()) + SIGN_EXPIRES_SECONDS)
|
||||
|
||||
# 2. 计算 Content-MD5(RFC2616: MD5 二进制摘要 → base64)
|
||||
content_md5 = self._compute_content_md5(body_bytes) if body_bytes else ""
|
||||
|
||||
# 3. 构建 CanonicalizedResource(去掉前导 /)
|
||||
canonicalized_resource = self._build_canonicalized_resource(path)
|
||||
|
||||
# 4. 构建签名字符串
|
||||
string_to_sign = (
|
||||
self.access_key_id + "\n"
|
||||
+ expires + "\n"
|
||||
+ method + "\n"
|
||||
+ content_md5 + "\n"
|
||||
+ canonicalized_resource
|
||||
)
|
||||
|
||||
# 5. HMAC-SHA1 签名 → base64 编码 → URL 编码
|
||||
signature_raw = hmac.new(
|
||||
self.access_key_secret.encode("utf-8"), # 密钥
|
||||
string_to_sign.encode("utf-8"), # 待签名字符串
|
||||
hashlib.sha1, # 算法
|
||||
).digest()
|
||||
signature_b64 = base64.b64encode(signature_raw).decode("utf-8")
|
||||
signature_encoded = quote(signature_b64, safe="")
|
||||
|
||||
# 6. 拼接 Authorization Header
|
||||
# 格式: "HRESS" + AccessKeyId + ":" + Expires + ":" + Signature
|
||||
authorization = f"HRESS{self.access_key_id}:{expires}:{signature_encoded}"
|
||||
|
||||
return {
|
||||
"Authorization": authorization,
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
}
|
||||
|
||||
# ======================================================================
|
||||
# 通用请求方法
|
||||
# ======================================================================
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
path: str,
|
||||
body: Optional[Dict[str, Any]] = None,
|
||||
) -> HuorongApiResponse:
|
||||
"""发送签名请求到火绒API。
|
||||
|
||||
统一处理:
|
||||
1. HRESS 签名 Authorization Header 生成
|
||||
2. HTTP请求发送(POST,超时控制)
|
||||
3. 响应解析和错误码处理
|
||||
4. 异常分类(认证/连接/API业务错误)
|
||||
|
||||
Args:
|
||||
path: API路径(如 /api/clnts/_list)
|
||||
body: 请求体字典(可选)
|
||||
|
||||
Returns:
|
||||
HuorongApiResponse: 火绒API响应
|
||||
|
||||
Raises:
|
||||
HuorongConnectionError: 网络不通或超时
|
||||
HuorongAuthError: 签名验证失败
|
||||
HuorongApiError: 火绒API返回业务错误
|
||||
"""
|
||||
# 构建完整URL
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
# 序列化请求体为字节(签名基于字节内容)
|
||||
body_bytes = json.dumps(body, separators=(",", ":")).encode("utf-8") if body else b"{}"
|
||||
|
||||
# 生成签名 Header
|
||||
headers = self._sign_request("POST", path, body_bytes)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
logger.debug(
|
||||
f"火绒API请求: POST {url}\n"
|
||||
f" AccessKeyID: {self.access_key_id}\n"
|
||||
f" Path: {path}\n"
|
||||
f" Body: {body_bytes[:200].decode('utf-8', errors='replace')}\n"
|
||||
f" Authorization: {headers.get('Authorization', 'N/A')[:60]}..."
|
||||
)
|
||||
response = await client.post(url, headers=headers, content=body_bytes)
|
||||
|
||||
# HTTP层面错误
|
||||
if response.status_code == 401:
|
||||
raise HuorongAuthError()
|
||||
if response.status_code != 200:
|
||||
raise HuorongApiError(
|
||||
code=response.status_code,
|
||||
message=f"HTTP {response.status_code}: {response.text[:200]}",
|
||||
)
|
||||
|
||||
# 解析JSON响应
|
||||
resp_data = response.json()
|
||||
api_resp = HuorongApiResponse(**resp_data)
|
||||
|
||||
# 火绒业务错误码处理
|
||||
# 官方文档定义的错误码:
|
||||
# - errno=0: 成功
|
||||
# - errno=1: 认证失败
|
||||
# - errno=2: 参数错误
|
||||
# - errno=3: 服务器内部错误
|
||||
# - errno=4: API未授权
|
||||
if api_resp.errcode != 0:
|
||||
if api_resp.errcode == 1 or api_resp.errcode in (401, 403):
|
||||
raise HuorongAuthError(f"认证/权限失败: {api_resp.errmsg}")
|
||||
if api_resp.errcode == 4:
|
||||
raise HuorongApiError(
|
||||
code=api_resp.errcode,
|
||||
message=f"API未授权: {api_resp.errmsg}",
|
||||
)
|
||||
raise HuorongApiError(
|
||||
code=api_resp.errcode,
|
||||
message=api_resp.errmsg,
|
||||
)
|
||||
|
||||
return api_resp
|
||||
|
||||
except httpx.TimeoutException:
|
||||
raise HuorongConnectionError(f"火绒API请求超时({self.timeout}秒): {url}")
|
||||
except httpx.ConnectError:
|
||||
raise HuorongConnectionError(f"无法连接火绒服务器: {url}")
|
||||
except (HuorongAuthError, HuorongApiError, HuorongConnectionError):
|
||||
# 已分类异常,直接向上抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
# 未预期异常,包装为通用错误
|
||||
logger.error(f"火绒API未预期异常: {type(e).__name__}: {e}")
|
||||
raise HuorongError(code=-1, message=f"火绒API调用异常: {e}")
|
||||
|
||||
# ======================================================================
|
||||
# P0 接口:查询能力
|
||||
# ======================================================================
|
||||
|
||||
async def list_terminals(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
page: int = 1,
|
||||
per_page: int = DEFAULT_PAGE_SIZE,
|
||||
) -> Dict[str, Any]:
|
||||
"""查询终端基本信息列表。
|
||||
|
||||
火绒API: POST /api/clnts/_list
|
||||
官方参数: limit(每页条数, 默认15, 最大200) + offset(起始索引, 默认0)
|
||||
本方法将 page/per_page 转换为 limit/offset,保持外部接口一致。
|
||||
|
||||
Args:
|
||||
group_id: 分组ID(可选,不传则查全部分组)
|
||||
page: 页码(从1开始,内部转换为offset)
|
||||
per_page: 每页条数(内部转换为limit)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 total(总数) 和 items(TerminalBasicInfo列表)
|
||||
"""
|
||||
# 火绒API使用 limit/offset 分页,不是 page/per_page
|
||||
limit = min(per_page, 200) # 火绒限制最大200
|
||||
offset = (page - 1) * limit
|
||||
|
||||
body: Dict[str, Any] = {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if group_id:
|
||||
body["group_id"] = int(group_id)
|
||||
|
||||
resp = await self._request("/api/clnts/_list", body)
|
||||
|
||||
# 解析响应数据
|
||||
data = resp.data or {}
|
||||
raw_items = data.get("list", [])
|
||||
total = data.get("total", 0)
|
||||
|
||||
items = [TerminalBasicInfo(**item) for item in raw_items]
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
}
|
||||
|
||||
async def get_terminal_detail(
|
||||
self,
|
||||
client_id: str,
|
||||
optional_fields: Optional[List[str]] = None,
|
||||
) -> TerminalDetailV2:
|
||||
"""获取终端详细信息v2。
|
||||
|
||||
火绒API: POST /api/clnts/_info2
|
||||
用途:获取终端硬件/软件/资产/网络配置等详细信息
|
||||
|
||||
Args:
|
||||
client_id: 终端唯一ID
|
||||
optional_fields: 需要返回的可选信息块
|
||||
可选值: hardware, software, assets, netconf
|
||||
默认全部返回
|
||||
|
||||
Returns:
|
||||
TerminalDetailV2: 终端详细信息
|
||||
"""
|
||||
if optional_fields is None:
|
||||
optional_fields = ["hardware", "software", "assets", "netconf"]
|
||||
|
||||
body = {
|
||||
"client_id": client_id,
|
||||
"optional_fields": optional_fields,
|
||||
}
|
||||
|
||||
resp = await self._request("/api/clnts/_info2", body)
|
||||
data = resp.data or {}
|
||||
|
||||
return TerminalDetailV2(**data)
|
||||
|
||||
async def list_terminal_leaks(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
page: int = 1,
|
||||
per_page: int = DEFAULT_PAGE_SIZE,
|
||||
) -> Dict[str, Any]:
|
||||
"""查询存在高危漏洞未修复的终端。
|
||||
|
||||
火绒API: POST /api/clnts/_leak
|
||||
官方参数: limit(每页条数, 默认15, 最大200) + offset(起始索引, 默认0)
|
||||
说明:返回的是"存在高危漏洞的终端列表",不是漏洞详情。
|
||||
每条记录是终端信息,字段名与 _list 不同:
|
||||
- cid (非 client_id)
|
||||
- hostname (非 computer_name)
|
||||
- ip_addr (非 local_ip)
|
||||
- stat (1=离线,2=在线,3=异常,非 is_online 布尔值)
|
||||
外层有 all_client(终端总数)和 risk_client(高危终端数)统计。
|
||||
|
||||
Args:
|
||||
group_id: 分组ID(可选,不传则查全部分组)
|
||||
page: 页码(从1开始,内部转换为offset)
|
||||
per_page: 每页条数(内部转换为limit)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 total(高危终端总数), risk_client(高危终端数),
|
||||
all_client(全部终端数) 和 items(TerminalLeakInfo列表)
|
||||
"""
|
||||
# 火绒API使用 limit/offset 分页
|
||||
limit = min(per_page, 200)
|
||||
offset = (page - 1) * limit
|
||||
|
||||
body: Dict[str, Any] = {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if group_id:
|
||||
body["group_id"] = int(group_id)
|
||||
|
||||
resp = await self._request("/api/clnts/_leak", body)
|
||||
|
||||
# 解析响应数据
|
||||
data = resp.data or {}
|
||||
raw_items = data.get("list", [])
|
||||
# _leak 不返回 total,但有 all_client 和 risk_client 统计
|
||||
all_client = data.get("all_client", 0)
|
||||
risk_client = data.get("risk_client", 0)
|
||||
|
||||
items = [TerminalLeakInfo(**item) for item in raw_items]
|
||||
|
||||
return {
|
||||
"total": risk_client, # 高危终端总数 = risk_client
|
||||
"all_client": all_client, # 全部终端数
|
||||
"risk_client": risk_client, # 高危终端数
|
||||
"items": items,
|
||||
}
|
||||
|
||||
async def get_virus_events(
|
||||
self,
|
||||
client_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
query_type: int = 2,
|
||||
begin_time: Optional[int] = None,
|
||||
end_time: Optional[int] = None,
|
||||
page: int = 1,
|
||||
per_page: int = DEFAULT_PAGE_SIZE,
|
||||
) -> Dict[str, Any]:
|
||||
"""查询终端病毒事件统计。
|
||||
|
||||
火绒API: POST /api/clnts/_virus_events
|
||||
官方参数:
|
||||
- type: 查询类型(必填)
|
||||
0=使用终端唯一标识查询(client_id字段必填)
|
||||
1=使用分组ID查询(group_id字段必填)
|
||||
2=查询全部终端日志(client_id和group_id字段可忽略)
|
||||
- client_id: 终端唯一标识
|
||||
- group_id: 分组ID
|
||||
- begin_time/end_time: 日志范围时间(Unix时间戳,默认全部时间)
|
||||
- limit/offset: 分页参数
|
||||
说明:返回终端维度的病毒日志统计,含 count(总数) 和
|
||||
result{success/fail/ignored/trusted}(处理结果明细)。
|
||||
|
||||
Args:
|
||||
client_id: 终端唯一ID(type=0时必填)
|
||||
group_id: 分组ID(type=1时必填)
|
||||
query_type: 查询类型,默认2(查全部)
|
||||
begin_time: 日志开始时间(Unix时间戳,可选)
|
||||
end_time: 日志结束时间(Unix时间戳,可选)
|
||||
page: 页码(从1开始,内部转换为offset)
|
||||
per_page: 每页条数(内部转换为limit)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 total(总数) 和 items(VirusEventStats列表)
|
||||
"""
|
||||
# 火绒API使用 limit/offset 分页
|
||||
limit = min(per_page, 200)
|
||||
offset = (page - 1) * limit
|
||||
|
||||
body: Dict[str, Any] = {
|
||||
"type": query_type,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
# 根据查询类型添加可选参数
|
||||
if query_type == 0 and client_id:
|
||||
body["client_id"] = client_id
|
||||
if query_type in (0, 1) and group_id:
|
||||
body["group_id"] = int(group_id)
|
||||
# 时间范围过滤
|
||||
if begin_time:
|
||||
body["begin_time"] = begin_time
|
||||
if end_time:
|
||||
body["end_time"] = end_time
|
||||
|
||||
resp = await self._request("/api/clnts/_virus_events", body)
|
||||
|
||||
data = resp.data or {}
|
||||
raw_items = data.get("list", [])
|
||||
total = data.get("total", 0)
|
||||
|
||||
items = [VirusEventStats(**item) for item in raw_items]
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"items": items,
|
||||
}
|
||||
|
||||
# ======================================================================
|
||||
# P1 接口:控制能力
|
||||
# ======================================================================
|
||||
|
||||
async def isolate_terminal(
|
||||
self,
|
||||
client_ids: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""隔离终端(断网)。
|
||||
|
||||
火绒API: POST /api/task/_create (type=netctrl, net_isolation=true)
|
||||
安全等级: 🔴 高危操作,调用方必须确保:
|
||||
1. 仅 admin 角色可调用
|
||||
2. 已完成二次确认
|
||||
3. 已记录操作原因
|
||||
|
||||
Args:
|
||||
client_ids: 目标终端ID列表
|
||||
|
||||
Returns:
|
||||
Dict: 火绒API响应的data部分
|
||||
"""
|
||||
body = {
|
||||
"type": "netctrl",
|
||||
"net_isolation": True,
|
||||
"clients": client_ids,
|
||||
}
|
||||
|
||||
resp = await self._request("/api/task/_create", body)
|
||||
logger.warning(f"火绒终端隔离操作: client_ids={client_ids}")
|
||||
return resp.data or {}
|
||||
|
||||
async def unisolate_terminal(
|
||||
self,
|
||||
client_ids: List[str],
|
||||
) -> Dict[str, Any]:
|
||||
"""解除终端隔离(恢复网络)。
|
||||
|
||||
火绒API: POST /api/task/_create (type=netctrl, net_isolation=false)
|
||||
|
||||
Args:
|
||||
client_ids: 目标终端ID列表
|
||||
|
||||
Returns:
|
||||
Dict: 火绒API响应的data部分
|
||||
"""
|
||||
body = {
|
||||
"type": "netctrl",
|
||||
"net_isolation": False,
|
||||
"clients": client_ids,
|
||||
}
|
||||
|
||||
resp = await self._request("/api/task/_create", body)
|
||||
logger.info(f"火绒终端解除隔离: client_ids={client_ids}")
|
||||
return resp.data or {}
|
||||
|
||||
async def create_scan_task(
|
||||
self,
|
||||
client_ids: List[str],
|
||||
scan_type: str = "quick_scan",
|
||||
) -> Dict[str, Any]:
|
||||
"""创建终端扫描任务。
|
||||
|
||||
火绒API: POST /api/task/_create
|
||||
扫描类型: quick_scan(快速扫描) / full_scan(全盘扫描) / custom_scan(自定义扫描)
|
||||
|
||||
Args:
|
||||
client_ids: 目标终端ID列表
|
||||
scan_type: 扫描类型,默认快速扫描
|
||||
|
||||
Returns:
|
||||
Dict: 火绒API响应的data部分
|
||||
"""
|
||||
body = {
|
||||
"type": scan_type,
|
||||
"clients": client_ids,
|
||||
}
|
||||
|
||||
resp = await self._request("/api/task/_create", body)
|
||||
logger.info(f"火绒终端扫描任务: type={scan_type}, client_ids={client_ids}")
|
||||
return resp.data or {}
|
||||
|
||||
async def send_notification(
|
||||
self,
|
||||
client_ids: List[str],
|
||||
content: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""向终端发送通知。
|
||||
|
||||
火绒API: POST /api/task/_create (type=message)
|
||||
|
||||
Args:
|
||||
client_ids: 目标终端ID列表
|
||||
content: 通知内容
|
||||
|
||||
Returns:
|
||||
Dict: 火绒API响应的data部分
|
||||
"""
|
||||
body = {
|
||||
"type": "message",
|
||||
"clients": client_ids,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
resp = await self._request("/api/task/_create", body)
|
||||
logger.info(f"火绒终端通知: client_ids={client_ids}, content={content[:50]}")
|
||||
return resp.data or {}
|
||||
|
||||
# ======================================================================
|
||||
# 测试连接
|
||||
# ======================================================================
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试火绒API连接是否正常。
|
||||
|
||||
使用 _list 接口(page=1, per_page=1)进行轻量级连接测试,
|
||||
验证签名是否正确、网络是否可达。
|
||||
|
||||
Returns:
|
||||
Dict: 包含 success(bool) 和 message(str)
|
||||
"""
|
||||
try:
|
||||
result = await self.list_terminals(page=1, per_page=1)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"连接成功,共 {result.get('total', 0)} 个终端",
|
||||
"total_terminals": result.get("total", 0),
|
||||
}
|
||||
except HuorongAuthError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"认证失败: {e.message}",
|
||||
}
|
||||
except HuorongConnectionError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"连接失败: {e.message}",
|
||||
}
|
||||
except HuorongError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"测试失败: {e.message}",
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 火绒集成配置管理
|
||||
# =============================================================================
|
||||
# 说明:从系统配置表(system_configs)读取火绒 AccessKey/Secret/BaseUrl,
|
||||
# 构建火绒API客户端实例
|
||||
# =============================================================================
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.system_config import SystemConfig
|
||||
from .client import HuorongClient
|
||||
from .exceptions import HuorongConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 火绒配置在 system_configs 表中的 key 前缀
|
||||
HUORONG_CONFIG_PREFIX = "integration_huorong_"
|
||||
|
||||
|
||||
async def get_huorong_client(db: AsyncSession) -> HuorongClient:
|
||||
"""从系统配置表构建火绒API客户端。
|
||||
|
||||
读取 integration_huorong_ 前缀的配置项,构建 HuorongClient 实例。
|
||||
如果任何必填配置缺失,抛出 HuorongConfigError。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
HuorongClient: 已配置的火绒API客户端实例
|
||||
|
||||
Raises:
|
||||
HuorongConfigError: AccessKey ID/Secret/Base URL 任一缺失
|
||||
"""
|
||||
# 读取三个必填配置
|
||||
result = await db.execute(
|
||||
select(SystemConfig).where(
|
||||
SystemConfig.config_key.startswith(HUORONG_CONFIG_PREFIX)
|
||||
)
|
||||
)
|
||||
configs = list(result.scalars().all())
|
||||
|
||||
# 构建 key→value 映射
|
||||
config_map = {cfg.config_key: cfg.config_value for cfg in configs}
|
||||
|
||||
access_key_id = config_map.get(f"{HUORONG_CONFIG_PREFIX}access_key_id", "")
|
||||
access_key_secret = config_map.get(f"{HUORONG_CONFIG_PREFIX}access_key_secret", "")
|
||||
base_url = config_map.get(f"{HUORONG_CONFIG_PREFIX}base_url", "")
|
||||
|
||||
# 校验必填项
|
||||
if not access_key_id or not access_key_secret or not base_url:
|
||||
missing = []
|
||||
if not access_key_id:
|
||||
missing.append("AccessKey ID")
|
||||
if not access_key_secret:
|
||||
missing.append("AccessKey Secret")
|
||||
if not base_url:
|
||||
missing.append("Base URL")
|
||||
raise HuorongConfigError(
|
||||
f"火绒集成配置不完整,缺失: {', '.join(missing)},请先在管理后台完成配置"
|
||||
)
|
||||
|
||||
return HuorongClient(
|
||||
access_key_id=access_key_id,
|
||||
access_key_secret=access_key_secret,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
|
||||
async def is_huorong_configured(db: AsyncSession) -> bool:
|
||||
"""检查火绒集成是否已完整配置。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
bool: 三项配置均存在且非空时返回 True
|
||||
"""
|
||||
try:
|
||||
client = await get_huorong_client(db)
|
||||
return bool(client.access_key_id and client.access_key_secret and client.base_url)
|
||||
except HuorongConfigError:
|
||||
return False
|
||||
@@ -0,0 +1,63 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 火绒集成自定义异常
|
||||
# =============================================================================
|
||||
# 说明:火绒API调用中可能抛出的各种异常类型
|
||||
# 包含:认证错误、连接超时、API错误码等
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class HuorongError(Exception):
|
||||
"""火绒集成基础异常。
|
||||
|
||||
所有火绒相关异常的父类,便于统一捕获处理。
|
||||
|
||||
Attributes:
|
||||
code: 错误码(火绒API返回的errcode,或自定义错误码)
|
||||
message: 错误描述
|
||||
"""
|
||||
|
||||
def __init__(self, code: int = -1, message: str = "火绒API调用失败"):
|
||||
self.code = code
|
||||
self.message = message
|
||||
super().__init__(f"[HuorongError:{code}] {message}")
|
||||
|
||||
|
||||
class HuorongAuthError(HuorongError):
|
||||
"""火绒认证失败异常。
|
||||
|
||||
场景:AccessKey ID/Secret 无效、签名校验失败、权限不足
|
||||
火绒API返回 errcode=401 或签名相关错误时抛出。
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "火绒API认证失败,请检查AccessKey配置"):
|
||||
super().__init__(code=401, message=message)
|
||||
|
||||
|
||||
class HuorongConnectionError(HuorongError):
|
||||
"""火绒连接失败异常。
|
||||
|
||||
场景:内网地址不通、超时、DNS解析失败
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "无法连接火绒服务器,请检查网络和Base URL配置"):
|
||||
super().__init__(code=502, message=message)
|
||||
|
||||
|
||||
class HuorongConfigError(HuorongError):
|
||||
"""火绒配置缺失异常。
|
||||
|
||||
场景:AccessKey ID/Secret/Base URL 未在系统配置中设置
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "火绒集成未配置,请先在管理后台设置AccessKey和Base URL"):
|
||||
super().__init__(code=400, message=message)
|
||||
|
||||
|
||||
class HuorongApiError(HuorongError):
|
||||
"""火绒API业务错误。
|
||||
|
||||
场景:火绒API返回非0 errcode(如参数错误、终端不存在等)
|
||||
"""
|
||||
|
||||
def __init__(self, code: int, message: str):
|
||||
super().__init__(code=code, message=message)
|
||||
@@ -0,0 +1,373 @@
|
||||
# =============================================================================
|
||||
# 企微IT智能服务台 — 火绒集成数据模型
|
||||
# =============================================================================
|
||||
# 说明:火绒API请求/响应的 Pydantic 数据模型
|
||||
# 包含:终端信息、漏洞信息、病毒事件、任务下发等
|
||||
# =============================================================================
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 通用响应模型
|
||||
# ==========================================================================
|
||||
|
||||
class HuorongApiResponse(BaseModel):
|
||||
"""火绒API统一响应模型。
|
||||
|
||||
火绒所有API返回格式一致(官方API文档 v1):
|
||||
成功时: { "errno": 0, "errmsg": "", "data": { ... } }
|
||||
失败时: { "errno": 1, "errmsg": "Authentication failed" }
|
||||
|
||||
官方错误码定义:
|
||||
- errno=0: 成功
|
||||
- errno=1: 认证失败
|
||||
- errno=2: 参数错误
|
||||
- errno=3: 服务器内部错误
|
||||
- errno=4: API未授权
|
||||
|
||||
注意:火绒API始终使用 errno(不是 errcode)。
|
||||
使用 model_validator 在验证前将 errno 归一化为 errcode,
|
||||
保持内部代码统一使用 errcode 字段。
|
||||
|
||||
Attributes:
|
||||
errcode: 错误码,0表示成功(从 errno 归一化而来)
|
||||
errmsg: 错误描述(成功时为空字符串)
|
||||
data: 业务数据(成功时非None)
|
||||
"""
|
||||
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def normalize_error_fields(cls, data: Any) -> Any:
|
||||
"""将火绒API返回的 errno 字段归一化为 errcode。
|
||||
|
||||
火绒API在认证失败等错误场景下返回 errno 而非 errcode,
|
||||
此验证器在 Pydantic 字段校验前将 errno 转换为 errcode,
|
||||
统一后续处理逻辑。
|
||||
|
||||
Args:
|
||||
data: 原始输入数据(通常为dict)
|
||||
|
||||
Returns:
|
||||
归一化后的数据
|
||||
"""
|
||||
if isinstance(data, dict) and 'errno' in data and 'errcode' not in data:
|
||||
data['errcode'] = data.pop('errno')
|
||||
return data
|
||||
|
||||
errcode: int = Field(..., description="错误码,0=成功")
|
||||
errmsg: str = Field(default="ok", description="错误描述")
|
||||
data: Optional[Any] = Field(default=None, description="业务数据")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端基本信息 — /api/clnts/_list 返回
|
||||
# ==========================================================================
|
||||
|
||||
class TerminalBasicInfo(BaseModel):
|
||||
"""终端基本信息(_list 接口返回的每条记录)。
|
||||
|
||||
字段名严格按照火绒API文档实际返回值定义。
|
||||
注意:API返回的字段名与之前猜测不同,已根据官方文档修正。
|
||||
|
||||
Attributes:
|
||||
id: 内部数据库ID
|
||||
client_id: 终端唯一ID(40位十六进制字符串,用于所有任务下发)
|
||||
client_name: 客户端名称
|
||||
computer_name: 计算机名
|
||||
local_ip: 本地IP
|
||||
connect_ip: 连接IP(客户端连接控制中心使用的IP)
|
||||
mac: MAC地址
|
||||
group_id: 分组ID
|
||||
os_version: 操作系统版本
|
||||
version: 火绒客户端版本
|
||||
definitions: 病毒库更新时间
|
||||
is_online: 在线状态
|
||||
last_connect_time: 最后连接时间(Unix时间戳)
|
||||
last_seen_time: 最后可见时间(Unix时间戳)
|
||||
first_appear_time: 首次出现时间(Unix时间戳)
|
||||
"""
|
||||
id: Optional[int] = Field(default=None, description="内部数据库ID")
|
||||
client_id: str = Field(..., description="终端唯一ID")
|
||||
client_name: str = Field(default="", description="客户端名称")
|
||||
computer_name: str = Field(default="", description="计算机名")
|
||||
local_ip: str = Field(default="", description="本地IP")
|
||||
connect_ip: str = Field(default="", description="连接IP")
|
||||
mac: str = Field(default="", description="MAC地址")
|
||||
group_id: Optional[Any] = Field(default=None, description="分组ID(int或str)")
|
||||
os_version: str = Field(default="", description="操作系统版本")
|
||||
version: str = Field(default="", description="火绒客户端版本")
|
||||
definitions: str = Field(default="", description="病毒库更新时间")
|
||||
is_online: bool = Field(default=False, description="在线状态")
|
||||
last_connect_time: Optional[int] = Field(default=None, description="最后连接时间")
|
||||
last_seen_time: Optional[int] = Field(default=None, description="最后可见时间")
|
||||
first_appear_time: Optional[int] = Field(default=None, description="首次出现时间")
|
||||
|
||||
|
||||
class TerminalListRequest(BaseModel):
|
||||
"""终端列表查询请求。
|
||||
|
||||
Attributes:
|
||||
group_id: 分组ID(可选,不传则查全部分组)
|
||||
page: 页码(从1开始)
|
||||
per_page: 每页条数
|
||||
"""
|
||||
group_id: Optional[str] = Field(default=None, description="分组ID")
|
||||
page: int = Field(default=1, ge=1, description="页码")
|
||||
per_page: int = Field(default=20, ge=1, le=100, description="每页条数")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端详细信息v2 — /api/clnts/_info2 返回
|
||||
# ==========================================================================
|
||||
|
||||
class HardwareInfo(BaseModel):
|
||||
"""终端硬件信息。
|
||||
|
||||
Attributes:
|
||||
cpu: CPU信息
|
||||
memory: 内存信息
|
||||
disk: 磁盘信息
|
||||
motherboard: 主板信息
|
||||
network_card: 网卡信息
|
||||
"""
|
||||
cpu: str = Field(default="", description="CPU信息")
|
||||
memory: str = Field(default="", description="内存信息")
|
||||
disk: str = Field(default="", description="磁盘信息")
|
||||
motherboard: str = Field(default="", description="主板信息")
|
||||
network_card: str = Field(default="", description="网卡信息")
|
||||
|
||||
|
||||
class SoftwareInfo(BaseModel):
|
||||
"""已安装软件条目。
|
||||
|
||||
Attributes:
|
||||
name: 软件名称
|
||||
version: 版本号
|
||||
publisher: 发布者
|
||||
"""
|
||||
name: str = Field(default="", description="软件名称")
|
||||
version: str = Field(default="", description="版本号")
|
||||
publisher: str = Field(default="", description="发布者")
|
||||
|
||||
|
||||
class AssetInfo(BaseModel):
|
||||
"""资产信息。
|
||||
|
||||
Attributes:
|
||||
asset_tag: 资产标签
|
||||
serial_number: 序列号
|
||||
"""
|
||||
asset_tag: str = Field(default="", description="资产标签")
|
||||
serial_number: str = Field(default="", description="序列号")
|
||||
|
||||
|
||||
class NetworkConfig(BaseModel):
|
||||
"""网络配置信息。
|
||||
|
||||
Attributes:
|
||||
ip: IP地址
|
||||
gateway: 网关
|
||||
dns: DNS服务器
|
||||
adapter_info: 网卡适配器信息
|
||||
"""
|
||||
ip: str = Field(default="", description="IP地址")
|
||||
gateway: str = Field(default="", description="网关")
|
||||
dns: str = Field(default="", description="DNS服务器")
|
||||
adapter_info: str = Field(default="", description="网卡适配器信息")
|
||||
|
||||
|
||||
class TerminalDetailV2(BaseModel):
|
||||
"""终端详细信息v2(_info2 接口返回)。
|
||||
|
||||
通过 optional_fields 参数指定需要返回的信息块:
|
||||
- hardware: 硬件信息
|
||||
- software: 已安装软件
|
||||
- assets: 资产信息
|
||||
- netconf: 网络配置
|
||||
|
||||
Attributes:
|
||||
client_id: 终端唯一ID
|
||||
computer_name: 计算机名
|
||||
hardware: 硬件信息(可选)
|
||||
software: 已安装软件列表(可选)
|
||||
assets: 资产信息(可选)
|
||||
netconf: 网络配置(可选)
|
||||
"""
|
||||
client_id: str = Field(..., description="终端唯一ID")
|
||||
computer_name: str = Field(default="", description="计算机名")
|
||||
hardware: Optional[HardwareInfo] = Field(default=None, description="硬件信息")
|
||||
software: Optional[List[SoftwareInfo]] = Field(default=None, description="已安装软件")
|
||||
assets: Optional[AssetInfo] = Field(default=None, description="资产信息")
|
||||
netconf: Optional[NetworkConfig] = Field(default=None, description="网络配置")
|
||||
|
||||
|
||||
class TerminalDetailRequest(BaseModel):
|
||||
"""终端详细信息查询请求。
|
||||
|
||||
Attributes:
|
||||
client_id: 终端唯一ID
|
||||
optional_fields: 需要返回的可选信息块列表
|
||||
"""
|
||||
client_id: str = Field(..., description="终端唯一ID")
|
||||
optional_fields: List[str] = Field(
|
||||
default_factory=lambda: ["hardware", "software", "assets", "netconf"],
|
||||
description="可选信息块: hardware/software/assets/netconf",
|
||||
)
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 漏洞信息 — /api/clnts/_leak 返回
|
||||
# 说明:_leak 接口返回的是"存在高危漏洞未修复的终端列表",
|
||||
# 每条记录是终端信息(非漏洞详情),API不返回具体漏洞CVE列表。
|
||||
# 外层还有 all_client(终端总数)和 risk_client(高危终端数)统计。
|
||||
# ==========================================================================
|
||||
|
||||
class TerminalLeakInfo(BaseModel):
|
||||
"""存在高危漏洞的终端信息(_leak 接口返回的每条记录)。
|
||||
|
||||
注意:_leak 返回的是终端维度数据,不是漏洞维度。
|
||||
字段名严格按照火绒API文档实际返回值定义。
|
||||
与 _list 接口的字段名不同!
|
||||
|
||||
Attributes:
|
||||
cid: 终端唯一ID(_leak 中叫 cid,_list 中叫 client_id)
|
||||
hostname: 计算机名(_leak 中叫 hostname,_list 中叫 computer_name)
|
||||
client_name: 终端名称
|
||||
group_name: 分组名称
|
||||
group_id: 分组ID
|
||||
ip_addr: 本地IP(_leak 中叫 ip_addr,_list 中叫 local_ip)
|
||||
call_ip: 连接IP(_leak 中叫 call_ip,_list 中叫 connect_ip)
|
||||
mac: MAC地址
|
||||
osver: 操作系统版本(_leak 中叫 osver,_list 中叫 os_version)
|
||||
os_type: 终端类型(如 Windows)
|
||||
prodver: 火绒客户端版本(_leak 中叫 prodver,_list 中叫 version)
|
||||
virdb: 病毒库版本(Unix时间戳,_leak 中叫 virdb,_list 中叫 definitions)
|
||||
stat: 在线状态码(1=离线, 2=在线, 3=异常,_list 中是 is_online 布尔值)
|
||||
"""
|
||||
cid: str = Field(..., description="终端唯一ID")
|
||||
hostname: str = Field(default="", description="计算机名")
|
||||
client_name: str = Field(default="", description="终端名称")
|
||||
group_name: str = Field(default="", description="分组名称")
|
||||
group_id: Optional[Any] = Field(default=None, description="分组ID")
|
||||
ip_addr: str = Field(default="", description="本地IP")
|
||||
call_ip: str = Field(default="", description="连接IP")
|
||||
mac: str = Field(default="", description="MAC地址")
|
||||
osver: str = Field(default="", description="操作系统版本")
|
||||
os_type: str = Field(default="", description="终端类型")
|
||||
prodver: str = Field(default="", description="火绒客户端版本")
|
||||
virdb: Optional[Any] = Field(default=None, description="病毒库版本(Unix时间戳)")
|
||||
stat: int = Field(default=1, description="在线状态码: 1=离线 2=在线 3=异常")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 病毒事件 — /api/clnts/_virus_events 返回
|
||||
# 说明:_virus_events 返回终端维度的病毒日志统计,
|
||||
# 含总数(count)和4种处理结果(result)的明细。
|
||||
# 请求需指定 type: 0=按client_id查, 1=按group_id查, 2=查全部
|
||||
# ==========================================================================
|
||||
|
||||
class VirusHandleResult(BaseModel):
|
||||
"""病毒事件处理结果统计。
|
||||
|
||||
Attributes:
|
||||
success: 处理成功数
|
||||
fail: 处理失败数
|
||||
ignored: 暂不处理数
|
||||
trusted: 已信任数
|
||||
"""
|
||||
success: int = Field(default=0, description="处理成功数")
|
||||
fail: int = Field(default=0, description="处理失败数")
|
||||
ignored: int = Field(default=0, description="暂不处理数")
|
||||
trusted: int = Field(default=0, description="已信任数")
|
||||
|
||||
|
||||
class VirusEventStats(BaseModel):
|
||||
"""终端病毒事件统计(_virus_events 接口返回的每条记录)。
|
||||
|
||||
字段名严格按照火绒API文档实际返回值定义。
|
||||
与 _list 接口的字段名基本一致。
|
||||
|
||||
Attributes:
|
||||
group_id: 分组ID
|
||||
client_id: 终端唯一ID
|
||||
client_name: 终端名称
|
||||
computer_name: 计算机名
|
||||
local_ip: 本地IP
|
||||
connect_ip: 连接IP
|
||||
mac: MAC地址
|
||||
count: 病毒日志总数
|
||||
result: 处理结果统计(success/fail/ignored/trusted)
|
||||
"""
|
||||
group_id: Optional[Any] = Field(default=None, description="分组ID")
|
||||
client_id: str = Field(..., description="终端唯一ID")
|
||||
client_name: str = Field(default="", description="终端名称")
|
||||
computer_name: str = Field(default="", description="计算机名")
|
||||
local_ip: str = Field(default="", description="本地IP")
|
||||
connect_ip: str = Field(default="", description="连接IP")
|
||||
mac: str = Field(default="", description="MAC地址")
|
||||
count: int = Field(default=0, description="病毒日志总数")
|
||||
result: Optional[VirusHandleResult] = Field(default=None, description="处理结果统计")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端任务 — /api/task/_create
|
||||
# ==========================================================================
|
||||
|
||||
class TaskCreateRequest(BaseModel):
|
||||
"""终端任务创建请求。
|
||||
|
||||
支持的任务类型:
|
||||
- quick_scan: 快速扫描
|
||||
- full_scan: 全盘扫描
|
||||
- custom_scan: 自定义扫描
|
||||
- netctrl: 终端隔离/解除
|
||||
- message: 发送通知
|
||||
|
||||
Attributes:
|
||||
task_type: 任务类型
|
||||
client_ids: 目标终端ID列表
|
||||
net_isolation: 是否隔离(仅 netctrl 类型有效)
|
||||
message_content: 通知内容(仅 message 类型有效)
|
||||
"""
|
||||
task_type: str = Field(..., description="任务类型: quick_scan/full_scan/custom_scan/netctrl/message")
|
||||
client_ids: List[str] = Field(..., min_length=1, description="目标终端ID列表")
|
||||
net_isolation: Optional[bool] = Field(default=None, description="是否隔离(仅netctrl类型)")
|
||||
message_content: Optional[str] = Field(default=None, description="通知内容(仅message类型)")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端安全画像(聚合模型,供前端直接使用)
|
||||
# ==========================================================================
|
||||
|
||||
class TerminalSecurityProfile(BaseModel):
|
||||
"""终端安全画像(聚合模型)。
|
||||
|
||||
将终端基本信息+安全状态聚合成一个模型,供坐席端直接展示。
|
||||
|
||||
Attributes:
|
||||
client_id: 终端唯一ID
|
||||
computer_name: 计算机名
|
||||
ip: 本地IP
|
||||
mac: MAC地址
|
||||
os_version: 操作系统版本
|
||||
is_online: 在线状态
|
||||
group_name: 分组名称
|
||||
hardware: 硬件概要
|
||||
high_risk_leaks: 高危漏洞数
|
||||
uncleaned_virus: 未处理病毒事件数
|
||||
security_score: 安全评分(0-100,综合漏洞+病毒+在线状态)
|
||||
"""
|
||||
client_id: str = Field(..., description="终端唯一ID")
|
||||
computer_name: str = Field(default="", description="计算机名")
|
||||
ip: str = Field(default="", description="本地IP")
|
||||
mac: str = Field(default="", description="MAC地址")
|
||||
os_version: str = Field(default="", description="操作系统版本")
|
||||
is_online: bool = Field(default=False, description="在线状态")
|
||||
group_name: str = Field(default="", description="分组名称")
|
||||
hardware: Optional[HardwareInfo] = Field(default=None, description="硬件概要")
|
||||
high_risk_leaks: int = Field(default=0, description="高危漏洞数")
|
||||
uncleaned_virus: int = Field(default=0, description="未处理病毒事件数")
|
||||
security_score: int = Field(default=100, description="安全评分(0-100)")
|
||||
@@ -0,0 +1,15 @@
|
||||
# 联软LV7000 API集成模块
|
||||
"""
|
||||
提供联软LV7000终端安全管理系统的API客户端。
|
||||
|
||||
认证方式:三层认证(IP白名单 + 账号密码 + Token)
|
||||
- 第一层:IP白名单(在联软后台配置WhiteListServerIp)
|
||||
- 第二层:账号密码(ApiAccount + ApiPassword)
|
||||
- 第三层:一次性Token(先调getToken获取,30分钟有效)
|
||||
|
||||
核心P0接口:
|
||||
- queryDevByParams:按条件查询终端(含strusername员工账号映射)
|
||||
- getDevAllInfo:终端详细信息(硬件+软件+资产+网络)
|
||||
- getUserInfoByAccount:按账号查用户信息
|
||||
- getAllOrgInfo:全量组织架构同步
|
||||
"""
|
||||
@@ -0,0 +1,604 @@
|
||||
# 联软LV7000 API客户端
|
||||
"""
|
||||
联软LV7000终端安全管理系统 API 客户端。
|
||||
|
||||
认证流程:
|
||||
1. 第一层:IP白名单(在联软后台配置,调用时自动生效)
|
||||
2. 第二层:账号密码(ApiAccount + ApiPassword)
|
||||
3. 第三层:Token(先调getToken获取,30分钟有效,自动缓存+刷新)
|
||||
|
||||
接口调用方式:
|
||||
- GET请求:参数通过query string传递
|
||||
- POST请求:参数通过form-data传递
|
||||
- 统一携带 token + apiAccount + apiPassword + validatekey
|
||||
|
||||
使用示例:
|
||||
client = LianruanClient(base_url, api_account, api_password, validate_key)
|
||||
terminals = await client.query_dev_by_params(strusername="songxian")
|
||||
detail = await client.get_dev_all_info(strdevname="IT-SONGXIAN")
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.integrations.lianruan.exceptions import (
|
||||
LianruanApiError,
|
||||
LianruanAuthError,
|
||||
LianruanConnectionError,
|
||||
)
|
||||
from app.integrations.lianruan.models import (
|
||||
TerminalBasicInfo,
|
||||
TerminalAllInfo,
|
||||
UserInfo,
|
||||
OrgDeptInfo,
|
||||
OnlineStatus,
|
||||
TerminalSoftwareInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LianruanClient:
|
||||
"""联软LV7000 API客户端。
|
||||
|
||||
Attributes:
|
||||
base_url: 联软API地址,如 http://192.168.x.x:30098
|
||||
api_account: API账号
|
||||
api_password: API密码
|
||||
validate_key: 验证密钥
|
||||
_token: 缓存的Token
|
||||
_token_expire: Token过期时间戳
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
api_account: str,
|
||||
api_password: str,
|
||||
validate_key: str = "",
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.api_account = api_account
|
||||
self.api_password = api_password
|
||||
self.validate_key = validate_key
|
||||
self.timeout = timeout
|
||||
|
||||
# Token缓存(30分钟有效,提前5分钟刷新)
|
||||
self._token: str = ""
|
||||
self._token_expire: float = 0.0
|
||||
|
||||
# httpx异步客户端(连接池复用)
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建httpx异步客户端(懒初始化+连接池复用)"""
|
||||
if self._client is None or self._client.is_closed:
|
||||
self._client = httpx.AsyncClient(
|
||||
timeout=self.timeout,
|
||||
verify=False, # 内网自签证书
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭httpx客户端,释放连接池"""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
|
||||
# ==========================================================================
|
||||
# Token管理
|
||||
# ==========================================================================
|
||||
|
||||
async def _ensure_token(self) -> str:
|
||||
"""确保Token有效,过期则自动刷新。
|
||||
|
||||
联软Token默认30分钟有效,提前5分钟刷新。
|
||||
|
||||
Returns:
|
||||
str: 有效的Token字符串
|
||||
"""
|
||||
now = time.time()
|
||||
# Token还有5分钟以上有效期,直接复用
|
||||
if self._token and now < self._token_expire - 300:
|
||||
return self._token
|
||||
|
||||
# 重新获取Token
|
||||
logger.info("联软Token过期或为空,正在刷新...")
|
||||
try:
|
||||
client = await self._get_client()
|
||||
url = f"{self.base_url}/token"
|
||||
params = {
|
||||
"act": "getToken",
|
||||
"apiAccount": self.api_account,
|
||||
"apiPassword": self.api_password,
|
||||
}
|
||||
if self.validate_key:
|
||||
params["validatekey"] = self.validate_key
|
||||
|
||||
resp = await client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if data.get("status") != "SUCCESS":
|
||||
raise LianruanAuthError(
|
||||
f"获取Token失败: {data.get('msg', '未知错误')}",
|
||||
detail=str(data),
|
||||
)
|
||||
|
||||
self._token = data.get("data", data.get("rows", ""))
|
||||
if not self._token:
|
||||
# 有些版本返回格式不同
|
||||
self._token = str(data.get("token", ""))
|
||||
|
||||
# 30分钟有效期
|
||||
self._token_expire = now + 1800
|
||||
logger.info("联软Token刷新成功,有效期至 %s",
|
||||
time.strftime("%H:%M:%S", time.localtime(self._token_expire)))
|
||||
return self._token
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
raise LianruanConnectionError(
|
||||
f"无法连接联软服务器 {self.base_url}: {e}",
|
||||
detail=str(e),
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise LianruanConnectionError(
|
||||
f"连接联软服务器超时: {e}",
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# 通用请求方法
|
||||
# ==========================================================================
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
path: str,
|
||||
act: str,
|
||||
params: Optional[dict] = None,
|
||||
method: str = "GET",
|
||||
) -> dict:
|
||||
"""发送请求到联软API。
|
||||
|
||||
自动携带认证参数(token + apiAccount + apiPassword)。
|
||||
|
||||
Args:
|
||||
path: API路径,如 /terminal 或 /querydeptuser
|
||||
act: 操作类型,如 queryDevByParams
|
||||
params: 额外业务参数
|
||||
method: 请求方法(GET/POST)
|
||||
|
||||
Returns:
|
||||
dict: 联软API返回的JSON数据
|
||||
|
||||
Raises:
|
||||
LianruanAuthError: 认证失败
|
||||
LianruanApiError: 业务错误
|
||||
LianruanConnectionError: 网络错误
|
||||
"""
|
||||
token = await self._ensure_token()
|
||||
client = await self._get_client()
|
||||
|
||||
# 构建完整参数:认证参数 + 业务参数
|
||||
full_params = {
|
||||
"act": act,
|
||||
"apiAccount": self.api_account,
|
||||
"apiPassword": self.api_password,
|
||||
"token": token,
|
||||
}
|
||||
if self.validate_key:
|
||||
full_params["validatekey"] = self.validate_key
|
||||
if params:
|
||||
full_params.update(params)
|
||||
|
||||
url = f"{self.base_url}{path}"
|
||||
|
||||
try:
|
||||
if method.upper() == "POST":
|
||||
resp = await client.post(url, data=full_params)
|
||||
else:
|
||||
resp = await client.get(url, params=full_params)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
raise LianruanConnectionError(
|
||||
f"无法连接联软服务器: {e}",
|
||||
detail=str(e),
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise LianruanConnectionError(
|
||||
f"请求联软超时: {e}",
|
||||
detail=str(e),
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise LianruanApiError(
|
||||
f"联软HTTP错误 {e.response.status_code}",
|
||||
status=str(e.response.status_code),
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# 检查联软业务状态码
|
||||
status = data.get("status", "")
|
||||
if status == "INVALID":
|
||||
# Token可能过期,清除缓存重试一次
|
||||
self._token = ""
|
||||
self._token_expire = 0
|
||||
raise LianruanAuthError(
|
||||
f"联软认证失败(IP不在白名单或Token无效): {data.get('msg', '')}",
|
||||
detail=str(data),
|
||||
)
|
||||
elif status == "ERROR":
|
||||
raise LianruanApiError(
|
||||
f"联软API错误: {data.get('msg', '')}",
|
||||
status=status,
|
||||
detail=str(data),
|
||||
)
|
||||
elif status == "Exceed":
|
||||
raise LianruanApiError(
|
||||
f"联软数据量超限: {data.get('msg', '')}",
|
||||
status=status,
|
||||
detail=str(data),
|
||||
)
|
||||
elif status != "SUCCESS":
|
||||
raise LianruanApiError(
|
||||
f"联软未知状态: {status} - {data.get('msg', '')}",
|
||||
status=status,
|
||||
detail=str(data),
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
# ==========================================================================
|
||||
# P0接口 — 终端设备查询
|
||||
# ==========================================================================
|
||||
|
||||
async def query_dev_by_params(
|
||||
self,
|
||||
strusername: str = "",
|
||||
strdevname: str = "",
|
||||
strdevip: str = "",
|
||||
strmac: str = "",
|
||||
page: int = 1,
|
||||
per_page: int = 20,
|
||||
) -> dict:
|
||||
"""查询终端设备(核心映射接口)。
|
||||
|
||||
⭐ strusername 参数可直接按员工账号查终端,这是联软最大的优势!
|
||||
|
||||
Args:
|
||||
strusername: 员工账号(映射金钥匙)
|
||||
strdevname: 计算机名
|
||||
strdevip: IP地址
|
||||
strmac: MAC地址
|
||||
page: 页码(从1开始)
|
||||
per_page: 每页条数
|
||||
|
||||
Returns:
|
||||
dict: {"items": [TerminalBasicInfo], "total": int}
|
||||
"""
|
||||
params: dict = {}
|
||||
if strusername:
|
||||
params["strusername"] = strusername
|
||||
if strdevname:
|
||||
params["strdevname"] = strdevname
|
||||
if strdevip:
|
||||
params["strdevip"] = strdevip
|
||||
if strmac:
|
||||
params["strmac"] = strmac
|
||||
|
||||
# 联软分页参数
|
||||
params["page"] = str(page)
|
||||
params["rows"] = str(per_page)
|
||||
|
||||
data = await self._request("/terminal", "queryDevByParams", params)
|
||||
|
||||
rows = data.get("rows", [])
|
||||
total = data.get("total", len(rows))
|
||||
items = [TerminalBasicInfo(**row) for row in rows]
|
||||
|
||||
return {"items": items, "total": total}
|
||||
|
||||
async def get_dev_all_info(
|
||||
self,
|
||||
strdevname: str = "",
|
||||
strdevip: str = "",
|
||||
) -> TerminalAllInfo:
|
||||
"""查询终端详细信息(极详细硬件+软件+资产+网络)。
|
||||
|
||||
比火绒_info2更丰富,包含逻辑磁盘使用率、显示器信息、内存条详情。
|
||||
|
||||
Args:
|
||||
strdevname: 计算机名(二选一)
|
||||
strdevip: IP地址(二选一)
|
||||
|
||||
Returns:
|
||||
TerminalAllInfo: 终端详细信息
|
||||
"""
|
||||
params: dict = {}
|
||||
if strdevname:
|
||||
params["strdevname"] = strdevname
|
||||
if strdevip:
|
||||
params["strdevip"] = strdevip
|
||||
|
||||
data = await self._request(
|
||||
"/devallinfoshowwithpaging", "getDevAllInfo", params
|
||||
)
|
||||
|
||||
# 返回格式:data.equipment + data.equipmentdetail
|
||||
equipment = data.get("equipment", data.get("rows", [{}]))
|
||||
if isinstance(equipment, list) and equipment:
|
||||
equipment = equipment[0]
|
||||
|
||||
equipment_detail = data.get("equipmentdetail", {})
|
||||
if isinstance(equipment_detail, list) and equipment_detail:
|
||||
equipment_detail = equipment_detail[0]
|
||||
dev_detail = equipment_detail.get("devdetail", equipment_detail)
|
||||
|
||||
# 解析硬件详情
|
||||
result = TerminalAllInfo(
|
||||
strdevname=equipment.get("strdevname", ""),
|
||||
strip1=equipment.get("strip1", ""),
|
||||
strmac=equipment.get("strmac", ""),
|
||||
strdeptname=equipment.get("strdeptname", ""),
|
||||
strusername=equipment.get("strusername", ""),
|
||||
struserdes=equipment.get("struserdes", ""),
|
||||
stros=equipment.get("stros", ""),
|
||||
strdomain=equipment.get("strdomain", ""),
|
||||
istatus=dev_detail.get("istatus", equipment.get("istatus", "")),
|
||||
strverofuaagent=dev_detail.get("strverofuaagent", ""),
|
||||
devassetno=dev_detail.get("devassetno", ""),
|
||||
devgroup=dev_detail.get("devgroup", ""),
|
||||
)
|
||||
|
||||
# 解析硬件列表
|
||||
self._parse_hardware_list(dev_detail, "CPUInformation", result.cpu)
|
||||
self._parse_hardware_list(dev_detail, "MemoryInformation", result.memory)
|
||||
self._parse_hardware_list(dev_detail, "HardDiskInformation", result.hard_disk)
|
||||
self._parse_hardware_list(dev_detail, "GraphicsCardInformation", result.graphics_card)
|
||||
self._parse_hardware_list(dev_detail, "MainboardInformation", result.mainboard)
|
||||
|
||||
# 解析逻辑磁盘(含使用率)
|
||||
for ld in dev_detail.get("LogicalDiskInformation", []):
|
||||
result.logical_disk.append(LogicalDiskInfo(
|
||||
name=ld.get("strlogicaldiskname", ""),
|
||||
file_system=ld.get("strfilesystem", ""),
|
||||
total_size=ld.get("strtotalsize", ""),
|
||||
free_space=ld.get("strfreespace", ""),
|
||||
usage_percent=ld.get("strusagepercent", ""),
|
||||
))
|
||||
|
||||
# 解析网卡
|
||||
for nc in dev_detail.get("NetworkCardInformation", []):
|
||||
result.network_card.append(NetworkCardInfo(
|
||||
name=nc.get("strnetcardname", ""),
|
||||
is_wireless=nc.get("iswireless", ""),
|
||||
vendor=nc.get("strnetcardvendor", ""),
|
||||
mac=nc.get("strnetcardmac", ""),
|
||||
))
|
||||
|
||||
# 解析显示器
|
||||
for d in dev_detail.get("DisplayInformation", []):
|
||||
result.display.append(DisplayInfo(
|
||||
vendor=d.get("strdisplayvendor", ""),
|
||||
model=d.get("strdisplaymodel", ""),
|
||||
serial=d.get("strdisplayserial", ""),
|
||||
size=d.get("strdisplaysize", ""),
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def _parse_hardware_list(
|
||||
self, dev_detail: dict, key: str, target_list: list
|
||||
) -> None:
|
||||
"""解析硬件信息列表(CPU/内存/硬盘等)"""
|
||||
from app.integrations.lianruan.models import HardwareInfo
|
||||
|
||||
for item in dev_detail.get(key, []):
|
||||
target_list.append(HardwareInfo(
|
||||
name=item.get("strcpuname", item.get("strmemname", item.get("strdiskname", ""))),
|
||||
model=item.get("strcpumodel", item.get("strmemmodel", item.get("strdiskmodel", ""))),
|
||||
vendor=item.get("strcpuvendor", item.get("strmemvendor", item.get("strdiskvendor", ""))),
|
||||
capacity=item.get("strcpufrequency", item.get("strmemcapacity", item.get("strdiskcapacity", ""))),
|
||||
serial=item.get("strcpuserial", item.get("strmemserial", item.get("strdiskserial", ""))),
|
||||
))
|
||||
|
||||
# ==========================================================================
|
||||
# P0接口 — 组织架构/用户
|
||||
# ==========================================================================
|
||||
|
||||
async def get_user_info_by_account(self, useraccount: str) -> Optional[UserInfo]:
|
||||
"""按账号查询用户信息。
|
||||
|
||||
Args:
|
||||
useraccount: 用户账号
|
||||
|
||||
Returns:
|
||||
UserInfo或None
|
||||
"""
|
||||
data = await self._request(
|
||||
"/querydeptuser",
|
||||
"getUserInfoByAccount",
|
||||
{"useraccount": useraccount},
|
||||
)
|
||||
rows = data.get("rows", data.get("row", []))
|
||||
if rows:
|
||||
row = rows[0] if isinstance(rows, list) else rows
|
||||
return UserInfo(**row)
|
||||
return None
|
||||
|
||||
async def get_all_org_info(self) -> list[OrgDeptInfo]:
|
||||
"""获取全量组织架构(部门+用户)。
|
||||
|
||||
用于定时同步,构建组织架构映射。
|
||||
|
||||
Returns:
|
||||
list[OrgDeptInfo]: 部门列表,每个部门含用户列表
|
||||
"""
|
||||
data = await self._request("/querydeptuser", "getAllOrgInfo")
|
||||
rows = data.get("rows", [])
|
||||
result = []
|
||||
for dept_data in rows:
|
||||
users = []
|
||||
for u in dept_data.get("users", []):
|
||||
users.append(UserInfo(**u))
|
||||
result.append(OrgDeptInfo(
|
||||
deptid=dept_data.get("deptid", ""),
|
||||
deptname=dept_data.get("deptname", ""),
|
||||
parentid=dept_data.get("parentid", ""),
|
||||
users=users,
|
||||
))
|
||||
return result
|
||||
|
||||
# ==========================================================================
|
||||
# P1接口 — 准入控制
|
||||
# ==========================================================================
|
||||
|
||||
async def exist_online_user(
|
||||
self, username: str, strdevip: str = ""
|
||||
) -> OnlineStatus:
|
||||
"""查询终端用户是否在线。
|
||||
|
||||
可精确判断某员工在某IP是否当前在线。
|
||||
|
||||
Args:
|
||||
username: 用户名
|
||||
strdevip: IP地址(可选)
|
||||
|
||||
Returns:
|
||||
OnlineStatus: 在线状态
|
||||
"""
|
||||
params = {"username": username}
|
||||
if strdevip:
|
||||
params["strdevip"] = strdevip
|
||||
|
||||
data = await self._request(
|
||||
"/access/onlineUser", "existOnlineUser", params
|
||||
)
|
||||
is_online = data.get("data", "0") == "1"
|
||||
return OnlineStatus(
|
||||
username=username,
|
||||
ip=strdevip,
|
||||
is_online=is_online,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# P1接口 — 终端操作
|
||||
# ==========================================================================
|
||||
|
||||
async def notice_agent_msg(
|
||||
self, strdevip: str, message: str
|
||||
) -> bool:
|
||||
"""向终端推送弹窗消息。
|
||||
|
||||
Args:
|
||||
strdevip: 终端IP
|
||||
message: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
data = await self._request(
|
||||
"/terminal",
|
||||
"noticeAgentMsg",
|
||||
{"strdevip": strdevip, "msg": message},
|
||||
)
|
||||
return data.get("status") == "SUCCESS"
|
||||
|
||||
async def remote_wake_up(
|
||||
self, strdevip: str, strmac: str
|
||||
) -> bool:
|
||||
"""远程唤醒终端。
|
||||
|
||||
Args:
|
||||
strdevip: 终端IP
|
||||
strmac: 终端MAC地址
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
data = await self._request(
|
||||
"/terminal",
|
||||
"remoteWakeUp",
|
||||
{"strdevip": strdevip, "strmac": strmac},
|
||||
)
|
||||
return data.get("status") == "SUCCESS"
|
||||
|
||||
async def query_software_by_dev(
|
||||
self, strdevname: str = "", strdevip: str = ""
|
||||
) -> Optional[TerminalSoftwareInfo]:
|
||||
"""查询终端安装软件。
|
||||
|
||||
Args:
|
||||
strdevname: 计算机名
|
||||
strdevip: IP地址
|
||||
|
||||
Returns:
|
||||
TerminalSoftwareInfo或None
|
||||
"""
|
||||
params: dict = {}
|
||||
if strdevname:
|
||||
params["strdevname"] = strdevname
|
||||
if strdevip:
|
||||
params["strdevip"] = strdevip
|
||||
|
||||
data = await self._request("/software", "querysoftwarebydev", params)
|
||||
rows = data.get("rows", [])
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
row = rows[0] if isinstance(rows, list) else rows
|
||||
softwares = []
|
||||
for s in row.get("softwares", []):
|
||||
softwares.append(SoftwareInfo(
|
||||
name=s.get("strsoftware", ""),
|
||||
version=s.get("strversion", ""),
|
||||
vendor=s.get("strvendor", ""),
|
||||
install_date=s.get("installdate", ""),
|
||||
))
|
||||
return TerminalSoftwareInfo(
|
||||
strdevname=row.get("strdevname", ""),
|
||||
strdevip=row.get("strdevip", ""),
|
||||
strmac=row.get("strmac", ""),
|
||||
strusername=row.get("strusername", ""),
|
||||
softwares=softwares,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# 测试连接
|
||||
# ==========================================================================
|
||||
|
||||
async def test_connection(self) -> dict:
|
||||
"""测试联软API连接。
|
||||
|
||||
使用getToken接口验证:
|
||||
1. 网络连通性
|
||||
2. IP白名单
|
||||
3. 账号密码正确性
|
||||
4. Token获取成功
|
||||
|
||||
Returns:
|
||||
dict: {"success": bool, "message": str}
|
||||
"""
|
||||
try:
|
||||
token = await self._ensure_token()
|
||||
if token:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "联软API连接成功,Token获取正常",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Token获取失败,返回为空",
|
||||
}
|
||||
except LianruanAuthError as e:
|
||||
return {"success": False, "message": e.message}
|
||||
except LianruanConnectionError as e:
|
||||
return {"success": False, "message": e.message}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"未知错误: {str(e)}"}
|
||||
@@ -0,0 +1,98 @@
|
||||
# 联软LV7000配置管理
|
||||
"""
|
||||
从system_configs表读取联软API配置,构建LianruanClient实例。
|
||||
|
||||
联软配置键(前缀 integration_lianruan_):
|
||||
- integration_lianruan_base_url: 联软API地址(如 http://192.168.x.x:30098)
|
||||
- integration_lianruan_api_account: API账号
|
||||
- integration_lianruan_api_password: API密码
|
||||
- integration_lianruan_validate_key: 验证密钥(可选)
|
||||
|
||||
配置方式:管理后台 → 系统集成 → 联软LV7000 → 填入账号密码
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.integrations.lianruan.client import LianruanClient
|
||||
from app.integrations.lianruan.exceptions import LianruanConfigError
|
||||
from app.models.system_config import SystemConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 联软配置键前缀(与 admin_service INTEGRATION_DEFINITIONS 中的 key_prefix 一致)
|
||||
_PREFIX = "integration_lianruan_"
|
||||
|
||||
|
||||
async def _get_lianruan_config_value(db: AsyncSession, key_suffix: str) -> str:
|
||||
"""读取单个联软配置值。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
key_suffix: 配置键后缀(如 base_url / api_account)
|
||||
|
||||
Returns:
|
||||
str: 配置值,不存在返回空字符串
|
||||
"""
|
||||
full_key = f"{_PREFIX}{key_suffix}"
|
||||
from sqlalchemy import select
|
||||
result = await db.execute(select(SystemConfig).where(SystemConfig.key == full_key))
|
||||
config_row = result.scalar_one_or_none()
|
||||
return config_row.value if config_row else ""
|
||||
|
||||
|
||||
async def get_lianruan_config(db: AsyncSession) -> dict:
|
||||
"""从system_configs表读取联软配置。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
dict: 包含 base_url / api_account / api_password / validate_key
|
||||
|
||||
Raises:
|
||||
LianruanConfigError: 配置缺失
|
||||
"""
|
||||
base_url = await _get_lianruan_config_value(db, "base_url")
|
||||
api_account = await _get_lianruan_config_value(db, "api_account")
|
||||
api_password = await _get_lianruan_config_value(db, "api_password")
|
||||
validate_key = await _get_lianruan_config_value(db, "validate_key")
|
||||
|
||||
if not base_url:
|
||||
raise LianruanConfigError("联软API未配置:缺少Base URL")
|
||||
if not api_account:
|
||||
raise LianruanConfigError("联软API未配置:缺少API账号")
|
||||
if not api_password:
|
||||
raise LianruanConfigError("联软API未配置:缺少API密码")
|
||||
|
||||
return {
|
||||
"base_url": base_url,
|
||||
"api_account": api_account,
|
||||
"api_password": api_password,
|
||||
"validate_key": validate_key,
|
||||
}
|
||||
|
||||
|
||||
async def get_lianruan_client(db: AsyncSession) -> LianruanClient:
|
||||
"""构建联软API客户端实例。
|
||||
|
||||
从system_configs表读取配置,创建LianruanClient。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
LianruanClient: 已配置的联软客户端
|
||||
|
||||
Raises:
|
||||
LianruanConfigError: 配置缺失
|
||||
"""
|
||||
cfg = await get_lianruan_config(db)
|
||||
|
||||
return LianruanClient(
|
||||
base_url=cfg["base_url"],
|
||||
api_account=cfg["api_account"],
|
||||
api_password=cfg["api_password"],
|
||||
validate_key=cfg.get("validate_key", ""),
|
||||
)
|
||||
@@ -0,0 +1,61 @@
|
||||
# 联软LV7000异常体系
|
||||
"""
|
||||
定义联软API集成的异常类层级。
|
||||
|
||||
层级:
|
||||
LianruanError — 基类(所有联软异常)
|
||||
├── LianruanConfigError — 配置缺失(未填写账号/密码/BaseURL)
|
||||
├── LianruanAuthError — 认证失败(IP不在白名单/账号密码错误/Token过期)
|
||||
├── LianruanConnectionError — 网络连接失败(超时/拒绝连接)
|
||||
└── LianruanApiError — API业务错误(参数错误/数据超限/其他)
|
||||
"""
|
||||
|
||||
|
||||
class LianruanError(Exception):
|
||||
"""联软异常基类"""
|
||||
|
||||
def __init__(self, message: str, detail: str = ""):
|
||||
self.message = message
|
||||
self.detail = detail
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class LianruanConfigError(LianruanError):
|
||||
"""配置缺失异常。
|
||||
|
||||
场景:未配置联软 BaseURL / ApiAccount / ApiPassword
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LianruanAuthError(LianruanError):
|
||||
"""认证失败异常。
|
||||
|
||||
场景:
|
||||
- IP不在白名单(status=INVALID)
|
||||
- 账号密码错误
|
||||
- Token过期(需重新获取)
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LianruanConnectionError(LianruanError):
|
||||
"""网络连接失败异常。
|
||||
|
||||
场景:超时/拒绝连接/DNS解析失败
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LianruanApiError(LianruanError):
|
||||
"""API业务错误异常。
|
||||
|
||||
场景:
|
||||
- 参数错误(status=ERROR)
|
||||
- 数据量超限(status=Exceed)
|
||||
- 其他业务异常
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, status: str = "", detail: str = ""):
|
||||
self.status = status # 联软返回的status字段(ERROR/Exceed等)
|
||||
super().__init__(message, detail)
|
||||
@@ -0,0 +1,193 @@
|
||||
# 联软LV7000数据模型
|
||||
"""
|
||||
定义联软API返回数据的Pydantic模型。
|
||||
|
||||
核心模型:
|
||||
- TerminalBasicInfo:终端基本信息(queryDevByParams返回)
|
||||
- TerminalAllInfo:终端详细信息(getDevAllInfo返回,极详细)
|
||||
- UserInfo:用户信息(getUserInfoByAccount返回)
|
||||
- OrgInfo:组织架构信息(getAllOrgInfo返回)
|
||||
- OnlineStatus:终端在线状态(existOnlineUser返回)
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端基本信息(queryDevByParams返回)
|
||||
# ==========================================================================
|
||||
|
||||
class TerminalBasicInfo(BaseModel):
|
||||
"""终端基本信息 — 最核心的映射数据源。
|
||||
|
||||
⭐ strusername + struserdes 字段直接提供员工账号→终端映射!
|
||||
这是联软相比火绒最大的优势。
|
||||
"""
|
||||
# 终端标识
|
||||
strdevname: str = Field(default="", description="计算机名")
|
||||
strdevip: str = Field(default="", description="IP地址")
|
||||
strmac: str = Field(default="", description="MAC地址")
|
||||
|
||||
# ⭐ 员工映射字段(核心价值)
|
||||
strusername: str = Field(default="", description="使用该终端的用户账号(映射金钥匙)")
|
||||
struserdes: str = Field(default="", description="用户姓名/描述")
|
||||
|
||||
# 组织信息
|
||||
strdeptname: str = Field(default="", description="所属部门名")
|
||||
|
||||
# 状态
|
||||
istatus: str = Field(default="", description="终端状态(1=在线/0=离线)")
|
||||
|
||||
# 网络
|
||||
strswitchname: str = Field(default="", description="接入交换机名")
|
||||
strifname: str = Field(default="", description="交换机接口名")
|
||||
|
||||
# 联系方式
|
||||
strmail: str = Field(default="", description="用户邮箱")
|
||||
strphone: str = Field(default="", description="用户电话")
|
||||
|
||||
# 其他
|
||||
strdomain: str = Field(default="", description="Windows域")
|
||||
strdevtype: str = Field(default="", description="设备类型")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端详细信息(getDevAllInfo返回)
|
||||
# ==========================================================================
|
||||
|
||||
class HardwareInfo(BaseModel):
|
||||
"""硬件组件信息"""
|
||||
name: str = Field(default="", description="名称")
|
||||
model: str = Field(default="", description="型号")
|
||||
vendor: str = Field(default="", description="厂商")
|
||||
capacity: str = Field(default="", description="容量")
|
||||
serial: str = Field(default="", description="序列号")
|
||||
|
||||
|
||||
class LogicalDiskInfo(BaseModel):
|
||||
"""逻辑磁盘信息(含使用率,判断磁盘满)"""
|
||||
name: str = Field(default="", description="卷标")
|
||||
file_system: str = Field(default="", description="文件系统")
|
||||
total_size: str = Field(default="", description="总量")
|
||||
free_space: str = Field(default="", description="可用空间")
|
||||
usage_percent: str = Field(default="", description="使用率")
|
||||
|
||||
|
||||
class NetworkCardInfo(BaseModel):
|
||||
"""网卡信息"""
|
||||
name: str = Field(default="", description="名称")
|
||||
is_wireless: str = Field(default="", description="是否无线")
|
||||
vendor: str = Field(default="", description="厂商")
|
||||
mac: str = Field(default="", description="MAC地址")
|
||||
|
||||
|
||||
class DisplayInfo(BaseModel):
|
||||
"""显示器信息(多屏配置排查)"""
|
||||
vendor: str = Field(default="", description="厂商")
|
||||
model: str = Field(default="", description="型号")
|
||||
serial: str = Field(default="", description="序列号")
|
||||
size: str = Field(default="", description="尺寸")
|
||||
|
||||
|
||||
class TerminalAllInfo(BaseModel):
|
||||
"""终端详细信息 — 极其详细,比火绒_info2更丰富。
|
||||
|
||||
包含:设备基础+硬件+软件+资产+网络配置。
|
||||
特别是逻辑磁盘使用率和显示器信息,是火绒没有的。
|
||||
"""
|
||||
# 设备基础
|
||||
strdevname: str = Field(default="", description="计算机名")
|
||||
strip1: str = Field(default="", description="IP地址")
|
||||
strmac: str = Field(default="", description="MAC地址")
|
||||
strnatip: str = Field(default="", description="NAT IP")
|
||||
macverdor: str = Field(default="", description="MAC厂商")
|
||||
strdevtype: str = Field(default="", description="设备类型")
|
||||
|
||||
# 组织+用户
|
||||
strdeptname: str = Field(default="", description="所属部门")
|
||||
strusername: str = Field(default="", description="用户账号⭐")
|
||||
struserdes: str = Field(default="", description="用户姓名⭐")
|
||||
|
||||
# 时间
|
||||
dtdevuptime: str = Field(default="", description="最近上线时间")
|
||||
dtdevdowntime: str = Field(default="", description="最近下线时间")
|
||||
dtdevfirstfoundtime: str = Field(default="", description="首次发现时间")
|
||||
|
||||
# 系统
|
||||
stros: str = Field(default="", description="操作系统")
|
||||
strdomain: str = Field(default="", description="Windows域")
|
||||
strserialnumber: str = Field(default="", description="序列号")
|
||||
strmainboardtype: str = Field(default="", description="主板型号")
|
||||
|
||||
# 客户端详情
|
||||
strverofuaagent: str = Field(default="", description="安全助手版本")
|
||||
istatus: str = Field(default="", description="在线状态")
|
||||
devassetno: str = Field(default="", description="设备资产号")
|
||||
devgroup: str = Field(default="", description="设备所属设备组")
|
||||
|
||||
# 硬件详情(列表)
|
||||
mainboard: list[HardwareInfo] = Field(default_factory=list, description="主板信息")
|
||||
cpu: list[HardwareInfo] = Field(default_factory=list, description="CPU信息")
|
||||
memory: list[HardwareInfo] = Field(default_factory=list, description="内存信息")
|
||||
hard_disk: list[HardwareInfo] = Field(default_factory=list, description="硬盘信息")
|
||||
logical_disk: list[LogicalDiskInfo] = Field(default_factory=list, description="逻辑磁盘")
|
||||
graphics_card: list[HardwareInfo] = Field(default_factory=list, description="显卡信息")
|
||||
network_card: list[NetworkCardInfo] = Field(default_factory=list, description="网卡信息")
|
||||
display: list[DisplayInfo] = Field(default_factory=list, description="显示器信息")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 用户信息(getUserInfoByAccount返回)
|
||||
# ==========================================================================
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
"""用户信息"""
|
||||
deptid: str = Field(default="", description="部门ID")
|
||||
userid: str = Field(default="", description="用户ID")
|
||||
useraccount: str = Field(default="", description="用户账号")
|
||||
username: str = Field(default="", description="用户姓名")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 组织架构信息(getAllOrgInfo返回)
|
||||
# ==========================================================================
|
||||
|
||||
class OrgDeptInfo(BaseModel):
|
||||
"""部门信息"""
|
||||
deptid: str = Field(default="", description="部门ID")
|
||||
deptname: str = Field(default="", description="部门名称")
|
||||
parentid: str = Field(default="", description="父部门ID")
|
||||
users: list[UserInfo] = Field(default_factory=list, description="部门下用户列表")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 终端在线状态(existOnlineUser返回)
|
||||
# ==========================================================================
|
||||
|
||||
class OnlineStatus(BaseModel):
|
||||
"""终端在线状态"""
|
||||
username: str = Field(default="", description="用户名")
|
||||
ip: str = Field(default="", description="IP地址")
|
||||
is_online: bool = Field(default=False, description="是否在线")
|
||||
|
||||
|
||||
# ==========================================================================
|
||||
# 软件信息(querysoftwarebydev返回)
|
||||
# ==========================================================================
|
||||
|
||||
class SoftwareInfo(BaseModel):
|
||||
"""软件安装信息"""
|
||||
name: str = Field(default="", description="软件名称")
|
||||
version: str = Field(default="", description="版本")
|
||||
vendor: str = Field(default="", description="厂商")
|
||||
install_date: str = Field(default="", description="安装日期")
|
||||
|
||||
|
||||
class TerminalSoftwareInfo(BaseModel):
|
||||
"""终端安装软件信息"""
|
||||
strdevname: str = Field(default="", description="计算机名")
|
||||
strdevip: str = Field(default="", description="IP地址")
|
||||
strmac: str = Field(default="", description="MAC地址")
|
||||
strusername: str = Field(default="", description="用户账号")
|
||||
softwares: list[SoftwareInfo] = Field(default_factory=list, description="软件列表")
|
||||
@@ -0,0 +1,35 @@
|
||||
# =============================================================================
|
||||
# RAGFlow 集成模块
|
||||
# =============================================================================
|
||||
|
||||
from .client import RagflowClient
|
||||
from .config import get_ragflow_client
|
||||
from .exceptions import (
|
||||
RagflowApiError,
|
||||
RagflowAuthError,
|
||||
RagflowConfigError,
|
||||
RagflowConnectionError,
|
||||
RagflowError,
|
||||
)
|
||||
from .models import (
|
||||
DatasetInfo,
|
||||
DocAggregate,
|
||||
DocumentInfo,
|
||||
RetrievalChunk,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RagflowClient",
|
||||
"get_ragflow_client",
|
||||
"RagflowError",
|
||||
"RagflowConfigError",
|
||||
"RagflowAuthError",
|
||||
"RagflowApiError",
|
||||
"RagflowConnectionError",
|
||||
"RetrievalChunk",
|
||||
"DocAggregate",
|
||||
"RetrievalResult",
|
||||
"DatasetInfo",
|
||||
"DocumentInfo",
|
||||
]
|
||||
@@ -0,0 +1,449 @@
|
||||
# =============================================================================
|
||||
# RAGFlow API 客户端
|
||||
# =============================================================================
|
||||
# 说明:封装 RAGFlow 知识检索引擎的 API 调用
|
||||
# 核心功能:
|
||||
# 1. 知识检索 — POST /api/v1/retrieval(核心接口)
|
||||
# 2. 数据集管理 — 列出/创建/删除知识库
|
||||
# 3. 文档管理 — 上传/列出/删除文档
|
||||
# 4. 测试连接 — 验证 API Key 是否有效
|
||||
# 认证方式:Authorization: Bearer <API_KEY>
|
||||
# 参考文档:https://ragflow.io/docs/http_api_reference
|
||||
# =============================================================================
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import (
|
||||
RagflowApiError,
|
||||
RagflowAuthError,
|
||||
RagflowConfigError,
|
||||
RagflowConnectionError,
|
||||
RagflowError,
|
||||
)
|
||||
from .models import (
|
||||
DatasetInfo,
|
||||
DocAggregate,
|
||||
DocumentInfo,
|
||||
RetrievalChunk,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认请求超时(秒)
|
||||
DEFAULT_TIMEOUT = 30.0
|
||||
|
||||
# 默认分页大小
|
||||
DEFAULT_PAGE_SIZE = 20
|
||||
|
||||
|
||||
class RagflowClient:
|
||||
"""RAGFlow API 客户端。
|
||||
|
||||
封装 RAGFlow 知识检索引擎的 API 调用,支持:
|
||||
- 知识检索(核心功能)
|
||||
- 数据集(知识库)管理
|
||||
- 文档管理
|
||||
- 连接测试
|
||||
|
||||
使用方式:
|
||||
client = RagflowClient(
|
||||
api_key="sk-xxx",
|
||||
base_url="http://10.80.0.85:9380"
|
||||
)
|
||||
result = await client.retrieval("VPN怎么连?", dataset_ids=["xxx"])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "http://10.80.0.85:9380",
|
||||
timeout: float = DEFAULT_TIMEOUT,
|
||||
):
|
||||
"""初始化 RAGFlow 客户端。
|
||||
|
||||
Args:
|
||||
api_key: RAGFlow API Key(Bearer Token)
|
||||
base_url: RAGFlow API 基础地址(不含尾部斜杠)
|
||||
timeout: 默认请求超时(秒)
|
||||
|
||||
Raises:
|
||||
RagflowConfigError: API Key 为空
|
||||
"""
|
||||
if not api_key:
|
||||
raise RagflowConfigError("RAGFlow API Key 不能为空")
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
"""构建请求头。
|
||||
|
||||
Returns:
|
||||
Dict: 包含 Authorization 和 Content-Type 的请求头
|
||||
"""
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
json_data: Optional[Dict] = None,
|
||||
params: Optional[Dict] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""统一请求封装。
|
||||
|
||||
Args:
|
||||
method: HTTP 方法(GET/POST/PUT/DELETE)
|
||||
path: API 路径(如 /api/v1/retrieval)
|
||||
json_data: JSON 请求体
|
||||
params: 查询参数
|
||||
timeout: 覆盖默认超时
|
||||
|
||||
Returns:
|
||||
Dict: API 响应的 JSON 数据
|
||||
|
||||
Raises:
|
||||
RagflowAuthError: 认证失败(401)
|
||||
RagflowApiError: API 返回错误
|
||||
RagflowConnectionError: 网络连接失败
|
||||
"""
|
||||
url = f"{self.base_url}{path}"
|
||||
req_timeout = timeout or self.timeout
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=self._headers(),
|
||||
json=json_data,
|
||||
params=params,
|
||||
timeout=req_timeout,
|
||||
)
|
||||
|
||||
# 处理 HTTP 错误
|
||||
if response.status_code == 401:
|
||||
raise RagflowAuthError("RAGFlow API Key 无效或已过期")
|
||||
|
||||
if response.status_code >= 400:
|
||||
try:
|
||||
err_body = response.json()
|
||||
err_msg = err_body.get("message", response.text)
|
||||
except Exception:
|
||||
err_msg = response.text
|
||||
raise RagflowApiError(
|
||||
code=response.status_code,
|
||||
message=f"RAGFlow API 错误 ({response.status_code}): {err_msg}",
|
||||
)
|
||||
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
# RAGFlow 统一响应格式:{code: 0, data: ..., message: ...}
|
||||
if result.get("code") != 0:
|
||||
raise RagflowApiError(
|
||||
code=result.get("code", -1),
|
||||
message=result.get("message", "未知错误"),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except httpx.TimeoutException:
|
||||
raise RagflowConnectionError(f"RAGFlow 请求超时 ({req_timeout}s): {path}")
|
||||
except httpx.ConnectError:
|
||||
raise RagflowConnectionError(f"RAGFlow 连接失败: {self.base_url}")
|
||||
except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RagflowError(f"RAGFlow 请求异常: {str(e)}")
|
||||
|
||||
# ==========================================================================
|
||||
# 测试连接
|
||||
# ==========================================================================
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试 RAGFlow API 连接。
|
||||
|
||||
通过列出数据集(limit=1)验证 API Key 是否有效。
|
||||
|
||||
Returns:
|
||||
Dict: {success: bool, message: str}
|
||||
"""
|
||||
try:
|
||||
result = await self.list_datasets(page=1, page_size=1)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"连接成功,共 {result.get('total', 0)} 个知识库",
|
||||
}
|
||||
except RagflowAuthError:
|
||||
return {"success": False, "message": "API Key 无效或已过期"}
|
||||
except RagflowConnectionError as e:
|
||||
return {"success": False, "message": f"连接失败: {e.message}"}
|
||||
except RagflowError as e:
|
||||
return {"success": False, "message": e.message}
|
||||
|
||||
# ==========================================================================
|
||||
# 知识检索(核心接口)
|
||||
# ==========================================================================
|
||||
|
||||
async def retrieval(
|
||||
self,
|
||||
question: str,
|
||||
dataset_ids: Optional[List[str]] = None,
|
||||
document_ids: Optional[List[str]] = None,
|
||||
similarity_threshold: float = 0.2,
|
||||
vector_similarity_weight: float = 0.3,
|
||||
top_k: int = 1024,
|
||||
keyword: bool = False,
|
||||
highlight: bool = False,
|
||||
) -> RetrievalResult:
|
||||
"""知识检索 — 从知识库中搜索相关文档片段。
|
||||
|
||||
这是 RAGFlow 的核心接口,用于根据用户问题检索最相关的文本块。
|
||||
|
||||
Args:
|
||||
question: 用户查询问题
|
||||
dataset_ids: 要搜索的数据集ID列表(与 document_ids 二选一)
|
||||
document_ids: 要搜索的文档ID列表
|
||||
similarity_threshold: 最小相似度阈值(0-1),默认 0.2
|
||||
vector_similarity_weight: 向量相似度权重(0-1),默认 0.3
|
||||
top_k: 参与计算的块数量,默认 1024
|
||||
keyword: 是否启用关键词匹配,默认 False
|
||||
highlight: 是否高亮匹配术语,默认 False
|
||||
|
||||
Returns:
|
||||
RetrievalResult: 检索结果(含文本块、文档聚合、总数)
|
||||
|
||||
Raises:
|
||||
RagflowError: 检索失败
|
||||
"""
|
||||
body: Dict[str, Any] = {
|
||||
"question": question,
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"vector_similarity_weight": vector_similarity_weight,
|
||||
"top_k": top_k,
|
||||
"keyword": keyword,
|
||||
"highlight": highlight,
|
||||
}
|
||||
|
||||
if dataset_ids:
|
||||
body["dataset_ids"] = dataset_ids
|
||||
if document_ids:
|
||||
body["document_ids"] = document_ids
|
||||
|
||||
result = await self._request("POST", "/api/v1/retrieval", json_data=body)
|
||||
|
||||
data = result.get("data", {})
|
||||
|
||||
# 解析文本块
|
||||
chunks = [
|
||||
RetrievalChunk.model_validate(chunk)
|
||||
for chunk in data.get("chunks", [])
|
||||
]
|
||||
|
||||
# 解析文档聚合
|
||||
doc_aggs = [
|
||||
DocAggregate.model_validate(agg)
|
||||
for agg in data.get("doc_aggs", [])
|
||||
]
|
||||
|
||||
return RetrievalResult(
|
||||
chunks=chunks,
|
||||
doc_aggs=doc_aggs,
|
||||
total=data.get("total", 0),
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# 数据集(知识库)管理
|
||||
# ==========================================================================
|
||||
|
||||
async def list_datasets(
|
||||
self,
|
||||
page: int = 1,
|
||||
page_size: int = DEFAULT_PAGE_SIZE,
|
||||
) -> Dict[str, Any]:
|
||||
"""列出所有数据集(知识库)。
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页条数
|
||||
|
||||
Returns:
|
||||
Dict: {items: List[DatasetInfo], total: int}
|
||||
"""
|
||||
result = await self._request(
|
||||
"GET",
|
||||
"/api/v1/datasets",
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
items = [
|
||||
DatasetInfo.model_validate(ds)
|
||||
for ds in data.get("datasets", [])
|
||||
]
|
||||
|
||||
return {"items": items, "total": data.get("total", 0)}
|
||||
|
||||
async def create_dataset(
|
||||
self,
|
||||
name: str,
|
||||
embedding_model: str = "BAAI/bge-m3@BAAI",
|
||||
chunk_method: str = "naive",
|
||||
permission: str = "me",
|
||||
) -> DatasetInfo:
|
||||
"""创建数据集(知识库)。
|
||||
|
||||
Args:
|
||||
name: 数据集名称
|
||||
embedding_model: 向量模型
|
||||
chunk_method: 分块方法(naive/qa/book/laws 等)
|
||||
permission: 权限(me/team)
|
||||
|
||||
Returns:
|
||||
DatasetInfo: 创建的数据集信息
|
||||
"""
|
||||
body = {
|
||||
"name": name,
|
||||
"embedding_model": embedding_model,
|
||||
"chunk_method": chunk_method,
|
||||
"permission": permission,
|
||||
}
|
||||
|
||||
result = await self._request("POST", "/api/v1/datasets", json_data=body)
|
||||
return DatasetInfo.model_validate(result.get("data", {}))
|
||||
|
||||
async def delete_dataset(self, dataset_ids: List[str]) -> bool:
|
||||
"""删除数据集。
|
||||
|
||||
Args:
|
||||
dataset_ids: 要删除的数据集ID列表
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
await self._request(
|
||||
"DELETE",
|
||||
"/api/v1/datasets",
|
||||
json_data={"ids": dataset_ids},
|
||||
)
|
||||
return True
|
||||
|
||||
# ==========================================================================
|
||||
# 文档管理
|
||||
# ==========================================================================
|
||||
|
||||
async def list_documents(
|
||||
self,
|
||||
dataset_id: str,
|
||||
page: int = 1,
|
||||
page_size: int = DEFAULT_PAGE_SIZE,
|
||||
) -> Dict[str, Any]:
|
||||
"""列出数据集中的文档。
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID
|
||||
page: 页码
|
||||
page_size: 每页条数
|
||||
|
||||
Returns:
|
||||
Dict: {items: List[DocumentInfo], total: int}
|
||||
"""
|
||||
result = await self._request(
|
||||
"GET",
|
||||
f"/api/v1/datasets/{dataset_id}/documents",
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
items = [
|
||||
DocumentInfo.model_validate(doc)
|
||||
for doc in data.get("documents", [])
|
||||
]
|
||||
|
||||
return {"items": items, "total": data.get("total", 0)}
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
dataset_id: str,
|
||||
file_path: str,
|
||||
file_name: Optional[str] = None,
|
||||
) -> DocumentInfo:
|
||||
"""上传文档到数据集。
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID
|
||||
file_path: 本地文件路径
|
||||
file_name: 文件名(可选,默认取 file_path 的文件名)
|
||||
|
||||
Returns:
|
||||
DocumentInfo: 上传的文档信息
|
||||
"""
|
||||
import os
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
raise RagflowError(f"文件不存在: {file_path}")
|
||||
|
||||
fname = file_name or os.path.basename(file_path)
|
||||
|
||||
url = f"{self.base_url}/api/v1/datasets/{dataset_id}/documents"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
with open(file_path, "rb") as f:
|
||||
response = await client.post(
|
||||
url=url,
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
files={"file": (fname, f)},
|
||||
timeout=60.0,
|
||||
)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise RagflowAuthError()
|
||||
|
||||
result = response.json()
|
||||
if result.get("code") != 0:
|
||||
raise RagflowApiError(
|
||||
code=result.get("code", -1),
|
||||
message=result.get("message", "上传失败"),
|
||||
)
|
||||
|
||||
docs = result.get("data", {}).get("documents", [])
|
||||
if docs:
|
||||
return DocumentInfo.model_validate(docs[0])
|
||||
return DocumentInfo(name=fname)
|
||||
|
||||
except (RagflowAuthError, RagflowApiError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise RagflowError(f"文档上传失败: {str(e)}")
|
||||
|
||||
async def delete_documents(
|
||||
self,
|
||||
dataset_id: str,
|
||||
document_ids: List[str],
|
||||
) -> bool:
|
||||
"""删除文档。
|
||||
|
||||
Args:
|
||||
dataset_id: 数据集ID
|
||||
document_ids: 要删除的文档ID列表
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
await self._request(
|
||||
"DELETE",
|
||||
f"/api/v1/datasets/{dataset_id}/documents",
|
||||
json_data={"ids": document_ids},
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,61 @@
|
||||
# =============================================================================
|
||||
# RAGFlow 配置加载器
|
||||
# =============================================================================
|
||||
# 说明:从数据库 system_configs 表加载 RAGFlow 配置,创建客户端实例
|
||||
# 配置项:integration_ragflow_api_url + integration_ragflow_api_key
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.system_config import SystemConfig
|
||||
|
||||
from .client import RagflowClient
|
||||
from .exceptions import RagflowConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认 RAGFlow API 地址(生产环境)
|
||||
DEFAULT_RAGFLOW_BASE_URL = "http://10.80.0.85:9380"
|
||||
|
||||
|
||||
async def _get_config(db: AsyncSession, key: str) -> str:
|
||||
"""从数据库读取单个配置值。"""
|
||||
result = await db.execute(
|
||||
select(SystemConfig.config_value).where(SystemConfig.config_key == key)
|
||||
)
|
||||
row = result.scalar()
|
||||
return row if row else ""
|
||||
|
||||
|
||||
async def get_ragflow_client(db: AsyncSession) -> RagflowClient:
|
||||
"""从数据库配置创建 RAGFlow 客户端实例。
|
||||
|
||||
读取 system_configs 表中的:
|
||||
- integration_ragflow_api_url: RAGFlow API 地址
|
||||
- integration_ragflow_api_key: RAGFlow API Key
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
RagflowClient: 客户端实例
|
||||
|
||||
Raises:
|
||||
RagflowConfigError: 配置缺失
|
||||
"""
|
||||
api_url = await _get_config(db, "integration_ragflow_api_url")
|
||||
api_key = await _get_config(db, "integration_ragflow_api_key")
|
||||
|
||||
# 如果数据库没有配置,使用默认地址
|
||||
if not api_url:
|
||||
api_url = DEFAULT_RAGFLOW_BASE_URL
|
||||
|
||||
if not api_key:
|
||||
raise RagflowConfigError(
|
||||
"RAGFlow API Key 未配置,请在管理后台 → 集成管理 → RAGFlow 中设置"
|
||||
)
|
||||
|
||||
return RagflowClient(api_key=api_key, base_url=api_url)
|
||||
@@ -0,0 +1,35 @@
|
||||
# =============================================================================
|
||||
# RAGFlow API 异常定义
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class RagflowError(Exception):
|
||||
"""RAGFlow 基础异常。"""
|
||||
def __init__(self, message: str = "RAGFlow 错误"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class RagflowConfigError(RagflowError):
|
||||
"""配置错误(缺少 API Key 或 Base URL)。"""
|
||||
def __init__(self, message: str = "RAGFlow 配置缺失"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RagflowAuthError(RagflowError):
|
||||
"""认证失败(API Key 无效)。"""
|
||||
def __init__(self, message: str = "RAGFlow 认证失败"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RagflowApiError(RagflowError):
|
||||
"""API 调用失败(非 200 响应)。"""
|
||||
def __init__(self, code: int = 0, message: str = "RAGFlow API 错误"):
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class RagflowConnectionError(RagflowError):
|
||||
"""网络连接失败。"""
|
||||
def __init__(self, message: str = "RAGFlow 连接失败"):
|
||||
super().__init__(message)
|
||||
@@ -0,0 +1,110 @@
|
||||
# =============================================================================
|
||||
# RAGFlow API 数据模型
|
||||
# =============================================================================
|
||||
# 说明:定义 RAGFlow API 请求/响应的 Pydantic 数据模型
|
||||
# 参考:https://ragflow.io/docs/http_api_reference
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RetrievalChunk(BaseModel):
|
||||
"""检索返回的单个文本块。
|
||||
|
||||
Attributes:
|
||||
id: 块唯一ID
|
||||
content: 块内容文本
|
||||
document_id: 所属文档ID
|
||||
document_keyword: 所属文档名称
|
||||
similarity: 综合相似度分数
|
||||
term_similarity: 关键词相似度
|
||||
vector_similarity: 向量相似度
|
||||
highlight: 高亮标记的内容(可选)
|
||||
"""
|
||||
id: str = Field(default="", description="块唯一ID")
|
||||
content: str = Field(default="", description="块内容文本")
|
||||
document_id: str = Field(default="", description="所属文档ID")
|
||||
document_keyword: str = Field(default="", description="所属文档名称")
|
||||
similarity: float = Field(default=0.0, description="综合相似度分数")
|
||||
term_similarity: float = Field(default=0.0, description="关键词相似度")
|
||||
vector_similarity: float = Field(default=0.0, description="向量相似度")
|
||||
highlight: Optional[str] = Field(default=None, description="高亮标记的内容")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DocAggregate(BaseModel):
|
||||
"""文档聚合统计。
|
||||
|
||||
Attributes:
|
||||
doc_id: 文档ID
|
||||
doc_name: 文档名称
|
||||
count: 命中的块数量
|
||||
"""
|
||||
doc_id: str = Field(default="", description="文档ID")
|
||||
doc_name: str = Field(default="", description="文档名称")
|
||||
count: int = Field(default=0, description="命中块数量")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class RetrievalResult(BaseModel):
|
||||
"""检索结果。
|
||||
|
||||
Attributes:
|
||||
chunks: 命中的文本块列表
|
||||
doc_aggs: 按文档聚合统计
|
||||
total: 命中总数
|
||||
"""
|
||||
chunks: List[RetrievalChunk] = Field(default_factory=list, description="命中文本块列表")
|
||||
doc_aggs: List[DocAggregate] = Field(default_factory=list, description="文档聚合统计")
|
||||
total: int = Field(default=0, description="命中总数")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DatasetInfo(BaseModel):
|
||||
"""数据集(知识库)信息。
|
||||
|
||||
Attributes:
|
||||
id: 数据集ID
|
||||
name: 数据集名称
|
||||
chunk_method: 分块方法
|
||||
permission: 权限
|
||||
document_count: 文档数量
|
||||
embedding_model: 向量模型
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
"""
|
||||
id: str = Field(default="", description="数据集ID")
|
||||
name: str = Field(default="", description="数据集名称")
|
||||
chunk_method: str = Field(default="naive", description="分块方法")
|
||||
permission: str = Field(default="me", description="权限")
|
||||
document_count: int = Field(default=0, description="文档数量")
|
||||
embedding_model: str = Field(default="", description="向量模型")
|
||||
create_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
update_time: Optional[str] = Field(default=None, description="更新时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class DocumentInfo(BaseModel):
|
||||
"""文档信息。
|
||||
|
||||
Attributes:
|
||||
id: 文档ID
|
||||
name: 文档名称
|
||||
chunk_method: 分块方法
|
||||
chunk_count: 块数量
|
||||
create_time: 创建时间
|
||||
update_time: 更新时间
|
||||
"""
|
||||
id: str = Field(default="", description="文档ID")
|
||||
name: str = Field(default="", description="文档名称")
|
||||
chunk_method: str = Field(default="naive", description="分块方法")
|
||||
chunk_count: int = Field(default=0, description="块数量")
|
||||
create_time: Optional[str] = Field(default=None, description="创建时间")
|
||||
update_time: Optional[str] = Field(default=None, description="更新时间")
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
Reference in New Issue
Block a user