450 lines
14 KiB
Python
450 lines
14 KiB
Python
# =============================================================================
# RAGFlow API client
# =============================================================================
# Purpose: wraps the HTTP API of the RAGFlow knowledge-retrieval engine.
# Core features:
#   1. Knowledge retrieval — POST /api/v1/retrieval (core endpoint)
#   2. Dataset management  — list / create / delete knowledge bases
#   3. Document management — upload / list / delete documents
#   4. Connection test     — verify that the API key is valid
# Authentication: Authorization: Bearer <API_KEY>
# Reference: https://ragflow.io/docs/http_api_reference
# =============================================================================
|
||
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
import httpx
|
||
|
||
from .exceptions import (
|
||
RagflowApiError,
|
||
RagflowAuthError,
|
||
RagflowConfigError,
|
||
RagflowConnectionError,
|
||
RagflowError,
|
||
)
|
||
from .models import (
|
||
DatasetInfo,
|
||
DocAggregate,
|
||
DocumentInfo,
|
||
RetrievalChunk,
|
||
RetrievalResult,
|
||
)
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default request timeout, in seconds.
DEFAULT_TIMEOUT: float = 30.0

# Default page size used by the paginated list calls.
DEFAULT_PAGE_SIZE: int = 20
|
||
|
||
|
||
class RagflowClient:
    """Async client for the RAGFlow knowledge-retrieval HTTP API.

    Supported operations:
        - knowledge retrieval (the core feature)
        - dataset (knowledge base) management
        - document management
        - connection test

    Every call authenticates with ``Authorization: Bearer <api_key>``.

    Example:
        client = RagflowClient(
            api_key="sk-xxx",
            base_url="http://10.80.0.85:9380",
        )
        result = await client.retrieval(
            "How do I connect to the VPN?", dataset_ids=["xxx"]
        )
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "http://10.80.0.85:9380",
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Initialise the client.

        Args:
            api_key: RAGFlow API key, sent as a bearer token.
            base_url: Base address of the RAGFlow API; trailing slashes are
                stripped.
            timeout: Default request timeout in seconds.

        Raises:
            RagflowConfigError: ``api_key`` is empty.
        """
        if not api_key:
            raise RagflowConfigError("RAGFlow API Key 不能为空")

        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _headers(self) -> Dict[str, str]:
        """Build the headers used for JSON requests.

        Returns:
            Dict: ``Authorization`` and ``Content-Type`` headers.
        """
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _parse_response(
        response: httpx.Response,
        default_message: str = "未知错误",
    ) -> Dict[str, Any]:
        """Validate an HTTP response and return its decoded JSON envelope.

        Shared by ``_request`` and ``upload_document`` so both paths report
        failures the same way.

        Args:
            response: The completed httpx response.
            default_message: Message used when the envelope reports a
                non-zero ``code`` but carries no ``message``.

        Returns:
            Dict: The decoded body, whose ``code`` is ``0``.

        Raises:
            RagflowAuthError: HTTP 401.
            RagflowApiError: Any other HTTP status >= 400, or a body whose
                ``code`` is not ``0``.
        """
        status = response.status_code

        if status == 401:
            raise RagflowAuthError("RAGFlow API Key 无效或已过期")

        if status >= 400:
            # Best effort: prefer the server's own message, fall back to the
            # raw body when it is not a JSON object.
            try:
                err_msg = response.json().get("message", response.text)
            except (ValueError, AttributeError):
                err_msg = response.text
            raise RagflowApiError(
                code=status,
                message=f"RAGFlow API 错误 ({status}): {err_msg}",
            )

        result = response.json()

        # The envelope looks like {code: 0, data: ..., message: ...}; any
        # other code is an application-level failure.
        if result.get("code") != 0:
            raise RagflowApiError(
                code=result.get("code", -1),
                message=result.get("message", default_message),
            )

        return result

    @staticmethod
    def _listing(result: Dict[str, Any], item_keys: List[str]) -> Dict[str, Any]:
        """Extract the raw item list and total count from a list-style reply.

        Accepts both shapes of ``data``:
            - a dict holding the items under one of ``item_keys`` plus a
              ``total`` count (the shape the original code expected);
            - a bare list of items.
        NOTE(review): the RAGFlow HTTP API reference appears to show a bare
        list for "list datasets" / "upload documents" and a ``docs`` key with
        ``total_datasets`` for "list documents" — confirm against the
        deployed server version.

        Args:
            result: Decoded response envelope.
            item_keys: Candidate keys for the item list, tried in order, when
                ``data`` is a dict.

        Returns:
            Dict: ``{"items": list of raw dicts, "total": int}``. When the
            server reports no count, ``total`` falls back to ``len(items)``.
        """
        data = result.get("data")
        raw_items: List[Any] = []
        total: Optional[int] = None

        if isinstance(data, list):
            raw_items = data
        elif isinstance(data, dict):
            for key in item_keys:
                if isinstance(data.get(key), list):
                    raw_items = data[key]
                    break
            total = data.get("total", data.get("total_datasets"))

        if total is None:
            total = result.get("total_datasets", result.get("total"))
        if total is None:
            total = len(raw_items)

        return {"items": raw_items, "total": total}

    async def _request(
        self,
        method: str,
        path: str,
        json_data: Optional[Dict] = None,
        params: Optional[Dict] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send one JSON request and return the decoded response envelope.

        Args:
            method: HTTP method (GET/POST/PUT/DELETE).
            path: API path, e.g. ``/api/v1/retrieval``.
            json_data: JSON request body.
            params: Query-string parameters.
            timeout: Per-call timeout in seconds; ``None`` uses the client
                default.

        Returns:
            Dict: The decoded JSON body.

        Raises:
            RagflowAuthError: Authentication failed (HTTP 401).
            RagflowApiError: The API reported an error.
            RagflowConnectionError: The request timed out or could not
                connect.
            RagflowError: Any other failure while sending or decoding.
        """
        url = f"{self.base_url}{path}"
        # Compare against None so that an explicit 0 is not silently replaced
        # by the default.
        req_timeout = self.timeout if timeout is None else timeout

        try:
            async with httpx.AsyncClient() as client:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self._headers(),
                    json=json_data,
                    params=params,
                    timeout=req_timeout,
                )
            return self._parse_response(response)

        except httpx.TimeoutException as exc:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({req_timeout}s): {path}"
            ) from exc
        except httpx.ConnectError as exc:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from exc
        except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
            # Already in the client's own error vocabulary; pass through.
            raise
        except Exception as exc:
            raise RagflowError(f"RAGFlow 请求异常: {str(exc)}") from exc

    # ==========================================================================
    # Connection test
    # ==========================================================================

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the API is reachable and the key is accepted.

        Lists datasets with ``page_size=1`` and converts the client's own
        errors into a result dict instead of raising.

        Returns:
            Dict: ``{success: bool, message: str}``.
        """
        try:
            result = await self.list_datasets(page=1, page_size=1)
            return {
                "success": True,
                "message": f"连接成功,共 {result.get('total', 0)} 个知识库",
            }
        except RagflowAuthError:
            return {"success": False, "message": "API Key 无效或已过期"}
        except RagflowConnectionError as e:
            return {"success": False, "message": f"连接失败: {e.message}"}
        except RagflowError as e:
            return {"success": False, "message": e.message}

    # ==========================================================================
    # Knowledge retrieval (core endpoint)
    # ==========================================================================

    async def retrieval(
        self,
        question: str,
        dataset_ids: Optional[List[str]] = None,
        document_ids: Optional[List[str]] = None,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top_k: int = 1024,
        keyword: bool = False,
        highlight: bool = False,
    ) -> RetrievalResult:
        """Search the knowledge base for chunks relevant to a question.

        Sends ``POST /api/v1/retrieval``.

        Args:
            question: The user's query.
            dataset_ids: Dataset IDs to search; omitted from the request
                when empty.
            document_ids: Document IDs to search; omitted from the request
                when empty.
            similarity_threshold: Minimum similarity (0-1), default 0.2.
            vector_similarity_weight: Weight of vector similarity (0-1),
                default 0.3.
            top_k: Number of chunks taking part in the computation,
                default 1024.
            keyword: Enable keyword matching, default False.
            highlight: Highlight matched terms, default False.

        Returns:
            RetrievalResult: Chunks, per-document aggregates and total count.

        Raises:
            RagflowError: The retrieval request failed.
        """
        body: Dict[str, Any] = {
            "question": question,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "keyword": keyword,
            "highlight": highlight,
        }
        if dataset_ids:
            body["dataset_ids"] = dataset_ids
        if document_ids:
            body["document_ids"] = document_ids

        result = await self._request("POST", "/api/v1/retrieval", json_data=body)

        # Treat a missing, null or non-dict ``data`` as an empty result
        # rather than crashing on ``.get``.
        data = result.get("data")
        if not isinstance(data, dict):
            data = {}

        chunks = [
            RetrievalChunk.model_validate(chunk)
            for chunk in (data.get("chunks") or [])
        ]
        doc_aggs = [
            DocAggregate.model_validate(agg)
            for agg in (data.get("doc_aggs") or [])
        ]

        return RetrievalResult(
            chunks=chunks,
            doc_aggs=doc_aggs,
            total=data.get("total") or 0,
        )

    # ==========================================================================
    # Dataset (knowledge base) management
    # ==========================================================================

    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List datasets (knowledge bases).

        Args:
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict: ``{items: List[DatasetInfo], total: int}``.
        """
        result = await self._request(
            "GET",
            "/api/v1/datasets",
            params={"page": page, "page_size": page_size},
        )

        listing = self._listing(result, ["datasets"])
        items = [DatasetInfo.model_validate(ds) for ds in listing["items"]]
        return {"items": items, "total": listing["total"]}

    async def create_dataset(
        self,
        name: str,
        embedding_model: str = "BAAI/bge-m3@BAAI",
        chunk_method: str = "naive",
        permission: str = "me",
    ) -> DatasetInfo:
        """Create a dataset (knowledge base).

        Args:
            name: Dataset name.
            embedding_model: Embedding model identifier.
            chunk_method: Chunking method (naive/qa/book/laws, ...).
            permission: Access permission (me/team).

        Returns:
            DatasetInfo: The created dataset.
        """
        body = {
            "name": name,
            "embedding_model": embedding_model,
            "chunk_method": chunk_method,
            "permission": permission,
        }

        result = await self._request("POST", "/api/v1/datasets", json_data=body)
        # ``or {}`` guards against an explicit ``"data": null``.
        return DatasetInfo.model_validate(result.get("data") or {})

    async def delete_dataset(self, dataset_ids: List[str]) -> bool:
        """Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            bool: ``True`` when the request succeeded; failures raise.
        """
        await self._request(
            "DELETE",
            "/api/v1/datasets",
            json_data={"ids": dataset_ids},
        )
        return True

    # ==========================================================================
    # Document management
    # ==========================================================================

    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List the documents of a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict: ``{items: List[DocumentInfo], total: int}``.
        """
        result = await self._request(
            "GET",
            f"/api/v1/datasets/{dataset_id}/documents",
            params={"page": page, "page_size": page_size},
        )

        listing = self._listing(result, ["documents", "docs"])
        items = [DocumentInfo.model_validate(doc) for doc in listing["items"]]
        return {"items": items, "total": listing["total"]}

    async def upload_document(
        self,
        dataset_id: str,
        file_path: str,
        file_name: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> DocumentInfo:
        """Upload a local file to a dataset.

        Args:
            dataset_id: Dataset ID.
            file_path: Path of the local file.
            file_name: Name to upload under; defaults to the basename of
                ``file_path``.
            timeout: Upload timeout in seconds; ``None`` means 60 seconds.

        Returns:
            DocumentInfo: The first document reported by the server, or a
            ``DocumentInfo`` carrying only the name when none is reported.

        Raises:
            RagflowAuthError: Authentication failed (HTTP 401).
            RagflowApiError: The API reported an error.
            RagflowConnectionError: The upload timed out or could not
                connect.
            RagflowError: The file is missing, or any other failure.
        """
        import os

        # isfile() rather than exists(): a directory path cannot be uploaded.
        if not os.path.isfile(file_path):
            raise RagflowError(f"文件不存在: {file_path}")

        fname = file_name or os.path.basename(file_path)
        path = f"/api/v1/datasets/{dataset_id}/documents"
        url = f"{self.base_url}{path}"
        upload_timeout = 60.0 if timeout is None else timeout

        try:
            async with httpx.AsyncClient() as client:
                with open(file_path, "rb") as f:
                    # Only Authorization is sent: a JSON Content-Type would
                    # conflict with the multipart body built from ``files``.
                    response = await client.post(
                        url=url,
                        headers={"Authorization": f"Bearer {self.api_key}"},
                        files={"file": (fname, f)},
                        timeout=upload_timeout,
                    )

            result = self._parse_response(response, default_message="上传失败")

            docs = self._listing(result, ["documents"])["items"]
            if docs:
                return DocumentInfo.model_validate(docs[0])
            return DocumentInfo(name=fname)

        except httpx.TimeoutException as exc:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({upload_timeout}s): {path}"
            ) from exc
        except httpx.ConnectError as exc:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from exc
        except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
            raise
        except Exception as exc:
            raise RagflowError(f"文档上传失败: {str(exc)}") from exc

    async def delete_documents(
        self,
        dataset_id: str,
        document_ids: List[str],
    ) -> bool:
        """Delete documents from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents to delete.

        Returns:
            bool: ``True`` when the request succeeded; failures raise.
        """
        await self._request(
            "DELETE",
            f"/api/v1/datasets/{dataset_id}/documents",
            json_data={"ids": document_ids},
        )
        return True