chore: initial baseline with P0-safety .gitignore
This commit is contained in:
@@ -0,0 +1,449 @@
|
||||
# =============================================================================
|
||||
# RAGFlow API 客户端
|
||||
# =============================================================================
|
||||
# 说明:封装 RAGFlow 知识检索引擎的 API 调用
|
||||
# 核心功能:
|
||||
# 1. 知识检索 — POST /api/v1/retrieval(核心接口)
|
||||
# 2. 数据集管理 — 列出/创建/删除知识库
|
||||
# 3. 文档管理 — 上传/列出/删除文档
|
||||
# 4. 测试连接 — 验证 API Key 是否有效
|
||||
# 认证方式:Authorization: Bearer <API_KEY>
|
||||
# 参考文档:https://ragflow.io/docs/http_api_reference
|
||||
# =============================================================================
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .exceptions import (
|
||||
RagflowApiError,
|
||||
RagflowAuthError,
|
||||
RagflowConfigError,
|
||||
RagflowConnectionError,
|
||||
RagflowError,
|
||||
)
|
||||
from .models import (
|
||||
DatasetInfo,
|
||||
DocAggregate,
|
||||
DocumentInfo,
|
||||
RetrievalChunk,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Default request timeout, in seconds.
DEFAULT_TIMEOUT: float = 30.0

# Default page size used by the paginated list calls.
DEFAULT_PAGE_SIZE: int = 20
|
||||
|
||||
|
||||
class RagflowClient:
    """Async client for the RAGFlow knowledge-retrieval HTTP API.

    Wraps the endpoints used by this project:

    - knowledge retrieval (``POST /api/v1/retrieval``)
    - dataset (knowledge base) listing / creation / deletion
    - document listing / upload / deletion
    - a connection test that validates the API key

    Every call authenticates with ``Authorization: Bearer <api_key>``. A new
    ``httpx.AsyncClient`` is opened for each request, so an instance only
    stores configuration (key, base URL, default timeout).

    Example::

        client = RagflowClient(
            api_key="sk-xxx",
            base_url="http://10.80.0.85:9380",
        )
        result = await client.retrieval("VPN怎么连?", dataset_ids=["xxx"])
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "http://10.80.0.85:9380",
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Store the connection settings.

        Args:
            api_key: RAGFlow API key, sent as a Bearer token.
            base_url: Base address of the RAGFlow API. Trailing slashes are
                stripped so paths can be appended directly.
            timeout: Default per-request timeout in seconds.

        Raises:
            RagflowConfigError: ``api_key`` is empty.
        """
        if not api_key:
            raise RagflowConfigError("RAGFlow API Key 不能为空")

        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _headers(self) -> Dict[str, str]:
        """Build the headers for a JSON request.

        Returns:
            A dict with the ``Authorization`` and ``Content-Type`` headers.
        """
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _parse_response(
        response: "httpx.Response",
        default_message: str = "未知错误",
    ) -> Dict[str, Any]:
        """Turn an HTTP response into the decoded RAGFlow envelope.

        Shared by ``_request`` and ``upload_document`` so both apply the same
        checks: HTTP 401, any other HTTP status >= 400, then the ``code``
        field of the JSON envelope ``{code: 0, data: ..., message: ...}``.

        Args:
            response: The response returned by httpx.
            default_message: Message used when the envelope reports a
                non-zero ``code`` without a ``message``.

        Returns:
            The decoded JSON body (the whole envelope, not just ``data``).

        Raises:
            RagflowAuthError: HTTP status 401.
            RagflowApiError: HTTP status >= 400, or envelope ``code`` != 0.
        """
        status = response.status_code

        if status == 401:
            raise RagflowAuthError("RAGFlow API Key 无效或已过期")

        if status >= 400:
            # Prefer the server's own "message"; fall back to the raw text
            # when the body is not JSON or is not a JSON object.
            try:
                err_msg = response.json().get("message", response.text)
            except (ValueError, AttributeError):
                err_msg = response.text
            raise RagflowApiError(
                code=status,
                message=f"RAGFlow API 错误 ({status}): {err_msg}",
            )

        result = response.json()
        if result.get("code") != 0:
            raise RagflowApiError(
                code=result.get("code", -1),
                message=result.get("message", default_message),
            )
        return result

    @staticmethod
    def _extract_list(payload: Any, *keys: str) -> List[Any]:
        """Return the record list carried by a response ``data`` payload.

        Accepts a dict that holds the list under one of ``keys`` (tried in
        order), a bare list, or ``None`` / anything else (yields ``[]``).

        NOTE(review): the bare-list case is there because the public RAGFlow
        HTTP API reference appears to show list payloads for some endpoints;
        confirm against the deployed server version.
        """
        if isinstance(payload, list):
            return payload
        if isinstance(payload, dict):
            for key in keys:
                value = payload.get(key)
                if isinstance(value, list):
                    return value
        return []

    @staticmethod
    def _extract_total(payload: Any, fallback: int) -> int:
        """Return ``payload["total"]`` for dict payloads, else ``fallback``.

        A dict without a ``total`` key yields 0.
        """
        if isinstance(payload, dict):
            return payload.get("total", 0)
        return fallback

    async def _request(
        self,
        method: str,
        path: str,
        json_data: Optional[Dict] = None,
        params: Optional[Dict] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send one JSON request and return the decoded envelope.

        Args:
            method: HTTP method (GET/POST/PUT/DELETE).
            path: API path appended to ``base_url`` (e.g. ``/api/v1/retrieval``).
            json_data: JSON request body.
            params: Query-string parameters.
            timeout: Per-call timeout in seconds; ``None`` uses the client's
                default.

        Returns:
            The decoded JSON envelope returned by the API.

        Raises:
            RagflowAuthError: HTTP status 401.
            RagflowApiError: HTTP status >= 400 or envelope ``code`` != 0.
            RagflowConnectionError: The request timed out or could not connect.
            RagflowError: Any other failure (e.g. a body that is not JSON).
        """
        url = f"{self.base_url}{path}"
        # Explicit None check so a caller-supplied 0 is not silently replaced.
        req_timeout = self.timeout if timeout is None else timeout

        try:
            async with httpx.AsyncClient() as client:
                response = await client.request(
                    method=method,
                    url=url,
                    headers=self._headers(),
                    json=json_data,
                    params=params,
                    timeout=req_timeout,
                )
            return self._parse_response(response)

        except httpx.TimeoutException as exc:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({req_timeout}s): {path}"
            ) from exc
        except httpx.ConnectError as exc:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from exc
        except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
            # Already in the client's own vocabulary: pass through unchanged.
            raise
        except Exception as exc:
            raise RagflowError(f"RAGFlow 请求异常: {str(exc)}") from exc

    # ==========================================================================
    # Connection test
    # ==========================================================================

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the API is reachable and the API key is accepted.

        Lists datasets with ``page_size=1`` and reports the outcome instead
        of raising for the client's own error types.

        Returns:
            ``{"success": bool, "message": str}``.
        """
        try:
            listing = await self.list_datasets(page=1, page_size=1)
        except RagflowAuthError:
            return {"success": False, "message": "API Key 无效或已过期"}
        except RagflowConnectionError as e:
            return {"success": False, "message": f"连接失败: {e.message}"}
        except RagflowError as e:
            return {"success": False, "message": e.message}

        return {
            "success": True,
            "message": f"连接成功,共 {listing.get('total', 0)} 个知识库",
        }

    # ==========================================================================
    # Knowledge retrieval (core endpoint)
    # ==========================================================================

    async def retrieval(
        self,
        question: str,
        dataset_ids: Optional[List[str]] = None,
        document_ids: Optional[List[str]] = None,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top_k: int = 1024,
        keyword: bool = False,
        highlight: bool = False,
    ) -> RetrievalResult:
        """Search the knowledge base for chunks relevant to a question.

        Sends ``POST /api/v1/retrieval``.

        Args:
            question: The user's query.
            dataset_ids: Dataset IDs to search. Omitted from the request when
                empty or ``None``.
            document_ids: Document IDs to search. Omitted from the request
                when empty or ``None``.
            similarity_threshold: Minimum similarity (0-1). Default 0.2.
            vector_similarity_weight: Weight of vector similarity (0-1).
                Default 0.3.
            top_k: Number of chunks taking part in the computation.
                Default 1024.
            keyword: Whether to enable keyword matching. Default False.
            highlight: Whether to highlight matched terms. Default False.

        Returns:
            A ``RetrievalResult`` with the chunks, the per-document
            aggregates and the total count.

        Raises:
            RagflowError: The request failed (see ``_request``).
        """
        body: Dict[str, Any] = {
            "question": question,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "keyword": keyword,
            "highlight": highlight,
        }
        if dataset_ids:
            body["dataset_ids"] = dataset_ids
        if document_ids:
            body["document_ids"] = document_ids

        result = await self._request("POST", "/api/v1/retrieval", json_data=body)

        # "data" may be present but null; treat that like an empty payload.
        data = result.get("data") or {}

        chunks = [
            RetrievalChunk.model_validate(chunk)
            for chunk in data.get("chunks") or []
        ]
        doc_aggs = [
            DocAggregate.model_validate(agg)
            for agg in data.get("doc_aggs") or []
        ]

        return RetrievalResult(
            chunks=chunks,
            doc_aggs=doc_aggs,
            total=data.get("total", 0),
        )

    # ==========================================================================
    # Dataset (knowledge base) management
    # ==========================================================================

    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List datasets (knowledge bases).

        Args:
            page: Page number.
            page_size: Items per page.

        Returns:
            ``{"items": List[DatasetInfo], "total": int}``. ``total`` comes
            from the payload's ``total`` field when the payload is a dict;
            for a bare-list payload it is the number of items on this page.
        """
        result = await self._request(
            "GET",
            "/api/v1/datasets",
            params={"page": page, "page_size": page_size},
        )

        data = result.get("data")
        items = [
            DatasetInfo.model_validate(ds)
            for ds in self._extract_list(data, "datasets")
        ]
        return {"items": items, "total": self._extract_total(data, len(items))}

    async def create_dataset(
        self,
        name: str,
        embedding_model: str = "BAAI/bge-m3@BAAI",
        chunk_method: str = "naive",
        permission: str = "me",
    ) -> DatasetInfo:
        """Create a dataset (knowledge base).

        Args:
            name: Dataset name.
            embedding_model: Embedding model identifier.
            chunk_method: Chunking method (naive/qa/book/laws, ...).
            permission: Access permission (me/team).

        Returns:
            The created dataset, validated from the response's ``data``.
        """
        body = {
            "name": name,
            "embedding_model": embedding_model,
            "chunk_method": chunk_method,
            "permission": permission,
        }

        result = await self._request("POST", "/api/v1/datasets", json_data=body)
        return DatasetInfo.model_validate(result.get("data") or {})

    async def delete_dataset(self, dataset_ids: List[str]) -> bool:
        """Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            ``True`` when the request succeeded; failures raise instead.
        """
        await self._request(
            "DELETE",
            "/api/v1/datasets",
            json_data={"ids": dataset_ids},
        )
        return True

    # ==========================================================================
    # Document management
    # ==========================================================================

    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List the documents of a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.

        Returns:
            ``{"items": List[DocumentInfo], "total": int}``.
        """
        result = await self._request(
            "GET",
            f"/api/v1/datasets/{dataset_id}/documents",
            params={"page": page, "page_size": page_size},
        )

        data = result.get("data")
        # NOTE(review): "docs" is accepted as well because the public API
        # reference appears to use that key; confirm against the server.
        items = [
            DocumentInfo.model_validate(doc)
            for doc in self._extract_list(data, "documents", "docs")
        ]
        return {"items": items, "total": self._extract_total(data, len(items))}

    async def upload_document(
        self,
        dataset_id: str,
        file_path: str,
        file_name: Optional[str] = None,
        timeout: float = 60.0,
    ) -> DocumentInfo:
        """Upload a local file to a dataset as a multipart request.

        Args:
            dataset_id: Dataset ID.
            file_path: Path of the local file.
            file_name: Name to send for the file; defaults to the base name
                of ``file_path``.
            timeout: Request timeout in seconds. Default 60.

        Returns:
            The first document described by the response, or a
            ``DocumentInfo`` carrying only the file name when the response
            lists none.

        Raises:
            RagflowAuthError: HTTP status 401.
            RagflowApiError: HTTP status >= 400 or envelope ``code`` != 0.
            RagflowConnectionError: The request timed out or could not
                connect (same mapping as ``_request``).
            RagflowError: The file does not exist, or any other failure.
        """
        import os

        if not os.path.exists(file_path):
            raise RagflowError(f"文件不存在: {file_path}")

        fname = file_name or os.path.basename(file_path)
        path = f"/api/v1/datasets/{dataset_id}/documents"
        url = f"{self.base_url}{path}"

        try:
            async with httpx.AsyncClient() as client:
                with open(file_path, "rb") as f:
                    # No JSON Content-Type here: httpx sets the multipart
                    # header itself when ``files`` is given.
                    response = await client.post(
                        url=url,
                        headers={"Authorization": f"Bearer {self.api_key}"},
                        files={"file": (fname, f)},
                        timeout=timeout,
                    )

            result = self._parse_response(response, default_message="上传失败")

            docs = self._extract_list(result.get("data"), "documents")
            if docs:
                return DocumentInfo.model_validate(docs[0])
            return DocumentInfo(name=fname)

        except httpx.TimeoutException as exc:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({timeout}s): {path}"
            ) from exc
        except httpx.ConnectError as exc:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from exc
        except (RagflowAuthError, RagflowApiError, RagflowConnectionError):
            raise
        except Exception as exc:
            raise RagflowError(f"文档上传失败: {str(exc)}") from exc

    async def delete_documents(
        self,
        dataset_id: str,
        document_ids: List[str],
    ) -> bool:
        """Delete documents from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents to delete.

        Returns:
            ``True`` when the request succeeded; failures raise instead.
        """
        await self._request(
            "DELETE",
            f"/api/v1/datasets/{dataset_id}/documents",
            json_data={"ids": document_ids},
        )
        return True
|
||||
Reference in New Issue
Block a user