Files
wecom_it_smart_desk/backend/app/integrations/ragflow/client.py
T

450 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# =============================================================================
# RAGFlow API 客户端
# =============================================================================
# 说明:封装 RAGFlow 知识检索引擎的 API 调用
# 核心功能:
# 1. 知识检索 — POST /api/v1/retrieval(核心接口)
# 2. 数据集管理 — 列出/创建/删除知识库
# 3. 文档管理 — 上传/列出/删除文档
# 4. 测试连接 — 验证 API Key 是否有效
# 认证方式:Authorization: Bearer <API_KEY>
# 参考文档:https://ragflow.io/docs/http_api_reference
# =============================================================================
import logging
from typing import Any, Dict, List, Optional
import httpx
from .exceptions import (
RagflowApiError,
RagflowAuthError,
RagflowConfigError,
RagflowConnectionError,
RagflowError,
)
from .models import (
DatasetInfo,
DocAggregate,
DocumentInfo,
RetrievalChunk,
RetrievalResult,
)
logger = logging.getLogger(__name__)
# Default request timeout, in seconds.
DEFAULT_TIMEOUT = 30.0
# Default page size used by the list endpoints.
DEFAULT_PAGE_SIZE = 20
class RagflowClient:
    """Async client for the RAGFlow knowledge-retrieval HTTP API.

    Wraps the endpoints used by this project:

    - knowledge retrieval (the core feature)
    - dataset (knowledge base) management
    - document management
    - connection test

    Every request carries an ``Authorization: Bearer <api_key>`` header.

    Example:
        client = RagflowClient(
            api_key="sk-xxx",
            base_url="http://10.80.0.85:9380",
        )
        result = await client.retrieval("How do I connect to the VPN?",
                                        dataset_ids=["xxx"])
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "http://10.80.0.85:9380",
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Initialise the client.

        Args:
            api_key: RAGFlow API key (sent as a Bearer token).
            base_url: Base address of the RAGFlow API. Trailing slashes are
                stripped so that paths can be appended directly.
            timeout: Default request timeout in seconds.

        Raises:
            RagflowConfigError: ``api_key`` is empty.
        """
        if not api_key:
            raise RagflowConfigError("RAGFlow API Key 不能为空")
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _headers(self) -> Dict[str, str]:
        """Build the headers for a JSON request.

        Returns:
            Dict with the ``Authorization`` and ``Content-Type`` headers.
        """
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _parse_response(
        response: httpx.Response,
        default_error: str = "未知错误",
    ) -> Dict[str, Any]:
        """Validate an HTTP response and return its decoded JSON body.

        Shared by ``_request`` and ``upload_document`` so that both report
        HTTP-level and API-level failures the same way.

        Args:
            response: The completed httpx response.
            default_error: Message used when the body signals failure
                (``code != 0``) but carries no ``message`` field.

        Returns:
            The decoded JSON body (``code`` is guaranteed to be ``0``).

        Raises:
            RagflowAuthError: The server answered 401.
            RagflowApiError: The server answered with status >= 400, or the
                body's ``code`` field is not ``0``.
        """
        status = response.status_code
        if status == 401:
            raise RagflowAuthError("RAGFlow API Key 无效或已过期")
        if status >= 400:
            # Best effort: prefer the API's own message, fall back to the raw
            # body when it is not JSON (or not a JSON object).
            try:
                err_msg = response.json().get("message", response.text)
            except Exception:
                err_msg = response.text
            raise RagflowApiError(
                code=status,
                message=f"RAGFlow API 错误 ({status}): {err_msg}",
            )
        result = response.json()
        # RAGFlow wraps every payload as {code: 0, data: ..., message: ...}.
        if result.get("code") != 0:
            raise RagflowApiError(
                code=result.get("code", -1),
                message=result.get("message", default_error),
            )
        return result

    @staticmethod
    def _unpack_listing(result: Dict[str, Any], *item_keys: str) -> tuple:
        """Extract ``(rows, total)`` from a list-style response body.

        Two layouts of ``result["data"]`` are accepted:

        - a dict holding the rows under one of ``item_keys`` plus an optional
          ``total`` field (the layout this client originally expected);
        - a bare JSON array of rows, in which case ``total`` is the number of
          rows in that array.

        NOTE(review): the public RAGFlow HTTP API reference appears to use the
        bare-array layout for some endpoints -- confirm against the deployed
        RAGFlow version.

        Args:
            result: Decoded response body.
            *item_keys: Candidate keys for the row list, tried in order.

        Returns:
            Tuple ``(rows, total)``; ``([], 0)`` when ``data`` is missing or
            of an unexpected type.
        """
        data = result.get("data")
        if isinstance(data, list):
            return data, len(data)
        if isinstance(data, dict):
            total = data.get("total", 0)
            for key in item_keys:
                rows = data.get(key)
                if isinstance(rows, list):
                    return rows, total
            return [], total
        return [], 0

    async def _request(
        self,
        method: str,
        path: str,
        json_data: Optional[Dict] = None,
        params: Optional[Dict] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send one JSON request and return the decoded response body.

        Args:
            method: HTTP method (GET/POST/PUT/DELETE).
            path: API path appended to ``base_url`` (e.g. /api/v1/retrieval).
            json_data: JSON request body.
            params: Query-string parameters.
            timeout: Overrides the client's default timeout for this call.

        Returns:
            The decoded JSON response body.

        Raises:
            RagflowAuthError: Authentication failed (401).
            RagflowApiError: The API returned an error.
            RagflowConnectionError: The request timed out or could not connect.
            RagflowError: Any other failure while sending or decoding.
        """
        url = f"{self.base_url}{path}"
        req_timeout = timeout or self.timeout
        try:
            async with httpx.AsyncClient() as client:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self._headers(),
                    json=json_data,
                    params=params,
                    timeout=req_timeout,
                )
                return self._parse_response(response)
        except httpx.TimeoutException as exc:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({req_timeout}s): {path}"
            ) from exc
        except httpx.ConnectError as exc:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from exc
        except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
            # Already in the client's own vocabulary -- pass through untouched.
            raise
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RagflowError(f"RAGFlow 请求异常: {str(e)}") from e

    # ==========================================================================
    # Connection test
    # ==========================================================================
    async def test_connection(self) -> Dict[str, Any]:
        """Check that the API is reachable and the API key is accepted.

        Performs a one-item ``list_datasets`` call and converts any client
        error into a failure result instead of raising.

        Returns:
            Dict of the form ``{success: bool, message: str}``.
        """
        try:
            listing = await self.list_datasets(page=1, page_size=1)
        except RagflowAuthError:
            return {"success": False, "message": "API Key 无效或已过期"}
        except RagflowConnectionError as e:
            return {"success": False, "message": f"连接失败: {e.message}"}
        except RagflowError as e:
            return {"success": False, "message": e.message}
        total = listing.get("total", 0)
        return {
            "success": True,
            "message": f"连接成功,共 {total} 个知识库",
        }

    # ==========================================================================
    # Knowledge retrieval (core endpoint)
    # ==========================================================================
    async def retrieval(
        self,
        question: str,
        dataset_ids: Optional[List[str]] = None,
        document_ids: Optional[List[str]] = None,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top_k: int = 1024,
        keyword: bool = False,
        highlight: bool = False,
    ) -> RetrievalResult:
        """Search the knowledge base for chunks relevant to a question.

        Sends ``POST /api/v1/retrieval``.

        Args:
            question: The user's query.
            dataset_ids: IDs of the datasets to search (give this or
                ``document_ids``). Omitted from the request when empty.
            document_ids: IDs of the documents to search. Omitted from the
                request when empty.
            similarity_threshold: Minimum similarity (0-1), default 0.2.
            vector_similarity_weight: Weight of vector similarity (0-1),
                default 0.3.
            top_k: Number of chunks taking part in the computation,
                default 1024.
            keyword: Enable keyword matching, default False.
            highlight: Highlight matched terms, default False.

        Returns:
            RetrievalResult with the chunks, per-document aggregates and total.

        Raises:
            RagflowError: The retrieval request failed.
        """
        body: Dict[str, Any] = {
            "question": question,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "keyword": keyword,
            "highlight": highlight,
        }
        if dataset_ids:
            body["dataset_ids"] = dataset_ids
        if document_ids:
            body["document_ids"] = document_ids

        result = await self._request("POST", "/api/v1/retrieval", json_data=body)
        # ``or {}`` also covers an explicit ``"data": null`` in the body.
        data = result.get("data") or {}

        chunks = [
            RetrievalChunk.model_validate(chunk)
            for chunk in data.get("chunks", [])
        ]
        doc_aggs = [
            DocAggregate.model_validate(agg)
            for agg in data.get("doc_aggs", [])
        ]
        return RetrievalResult(
            chunks=chunks,
            doc_aggs=doc_aggs,
            total=data.get("total", 0),
        )

    # ==========================================================================
    # Dataset (knowledge base) management
    # ==========================================================================
    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List datasets (knowledge bases).

        Args:
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict of the form ``{items: List[DatasetInfo], total: int}``. When
            the server returns a bare array, ``total`` is the number of rows
            on this page.
        """
        result = await self._request(
            "GET",
            "/api/v1/datasets",
            params={"page": page, "page_size": page_size},
        )
        rows, total = self._unpack_listing(result, "datasets")
        items = [DatasetInfo.model_validate(ds) for ds in rows]
        return {"items": items, "total": total}

    async def create_dataset(
        self,
        name: str,
        embedding_model: str = "BAAI/bge-m3@BAAI",
        chunk_method: str = "naive",
        permission: str = "me",
    ) -> DatasetInfo:
        """Create a dataset (knowledge base).

        Args:
            name: Dataset name.
            embedding_model: Embedding model identifier.
            chunk_method: Chunking method (naive/qa/book/laws, ...).
            permission: Access permission (me/team).

        Returns:
            DatasetInfo describing the created dataset.
        """
        body = {
            "name": name,
            "embedding_model": embedding_model,
            "chunk_method": chunk_method,
            "permission": permission,
        }
        result = await self._request("POST", "/api/v1/datasets", json_data=body)
        return DatasetInfo.model_validate(result.get("data") or {})

    async def delete_dataset(self, dataset_ids: List[str]) -> bool:
        """Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            True when the request succeeded (failures raise instead).
        """
        await self._request(
            "DELETE",
            "/api/v1/datasets",
            json_data={"ids": dataset_ids},
        )
        return True

    # ==========================================================================
    # Document management
    # ==========================================================================
    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List the documents of a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict of the form ``{items: List[DocumentInfo], total: int}``.
        """
        result = await self._request(
            "GET",
            f"/api/v1/datasets/{dataset_id}/documents",
            params={"page": page, "page_size": page_size},
        )
        # NOTE(review): "docs" is accepted next to the original "documents"
        # key because the public API reference appears to use it -- confirm.
        rows, total = self._unpack_listing(result, "documents", "docs")
        items = [DocumentInfo.model_validate(doc) for doc in rows]
        return {"items": items, "total": total}

    async def upload_document(
        self,
        dataset_id: str,
        file_path: str,
        file_name: Optional[str] = None,
        timeout: float = 60.0,
    ) -> DocumentInfo:
        """Upload a local file to a dataset.

        Args:
            dataset_id: Dataset ID.
            file_path: Path of the local file.
            file_name: Name to upload under; defaults to the base name of
                ``file_path``.
            timeout: Request timeout in seconds (uploads default to 60s).

        Returns:
            DocumentInfo for the first uploaded document, or a DocumentInfo
            carrying only the file name when the response lists none.

        Raises:
            RagflowAuthError: Authentication failed (401).
            RagflowApiError: The API returned an error.
            RagflowError: The file does not exist, or any other failure.
        """
        import os

        if not os.path.exists(file_path):
            raise RagflowError(f"文件不存在: {file_path}")
        fname = file_name or os.path.basename(file_path)
        url = f"{self.base_url}/api/v1/datasets/{dataset_id}/documents"
        try:
            async with httpx.AsyncClient() as client:
                with open(file_path, "rb") as f:
                    response = await client.post(
                        url=url,
                        # No Content-Type here: httpx must generate the
                        # multipart boundary header itself.
                        headers={"Authorization": f"Bearer {self.api_key}"},
                        files={"file": (fname, f)},
                        timeout=timeout,
                    )
            result = self._parse_response(response, default_error="上传失败")
            # ``data`` may be a bare array or a dict with a "documents" list.
            docs, _ = self._unpack_listing(result, "documents")
            if docs:
                return DocumentInfo.model_validate(docs[0])
            return DocumentInfo(name=fname)
        except (RagflowAuthError, RagflowApiError):
            raise
        except Exception as e:
            raise RagflowError(f"文档上传失败: {str(e)}") from e

    async def delete_documents(
        self,
        dataset_id: str,
        document_ids: List[str],
    ) -> bool:
        """Delete documents from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents to delete.

        Returns:
            True when the request succeeded (failures raise instead).
        """
        await self._request(
            "DELETE",
            f"/api/v1/datasets/{dataset_id}/documents",
            json_data={"ids": document_ids},
        )
        return True