# =============================================================================
# RAGFlow API client
# =============================================================================
# Purpose: wraps the HTTP API of the RAGFlow knowledge-retrieval engine.
# Core features:
#   1. Knowledge retrieval — POST /api/v1/retrieval (the core endpoint)
#   2. Dataset management  — list / create / delete knowledge bases
#   3. Document management — upload / list / delete documents
#   4. Connection test     — verify that the API key is valid
# Authentication: Authorization: Bearer <api_key>
# Reference: https://ragflow.io/docs/http_api_reference
# =============================================================================

import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import httpx

from .exceptions import (
    RagflowApiError,
    RagflowAuthError,
    RagflowConfigError,
    RagflowConnectionError,
    RagflowError,
)
from .models import (
    DatasetInfo,
    DocAggregate,
    DocumentInfo,
    RetrievalChunk,
    RetrievalResult,
)

logger = logging.getLogger(__name__)

# Default request timeout (seconds)
DEFAULT_TIMEOUT = 30.0

# Default page size for list endpoints
DEFAULT_PAGE_SIZE = 20

# Default timeout (seconds) for document uploads, which may carry large files
DEFAULT_UPLOAD_TIMEOUT = 60.0


class RagflowClient:
    """Client for the RAGFlow HTTP API.

    Wraps the RAGFlow knowledge-retrieval engine and supports:
      - knowledge retrieval (core feature)
      - dataset (knowledge base) management
      - document management
      - connection testing

    Every call opens a short-lived ``httpx.AsyncClient``; no connection pool is
    kept between calls.

    Usage:
        client = RagflowClient(
            api_key="sk-xxx",
            base_url="http://10.80.0.85:9380",
        )
        result = await client.retrieval(
            "How do I connect to the VPN?", dataset_ids=["xxx"]
        )
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "http://10.80.0.85:9380",
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """Initialize the RAGFlow client.

        Args:
            api_key: RAGFlow API key (sent as a Bearer token).
            base_url: RAGFlow API base URL; a trailing slash is stripped.
            timeout: Default request timeout in seconds.

        Raises:
            RagflowConfigError: The API key is empty.
        """
        if not api_key:
            raise RagflowConfigError("RAGFlow API Key 不能为空")
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _headers(self) -> Dict[str, str]:
        """Build the headers used for JSON requests.

        Returns:
            Dict: Authorization and Content-Type headers.
        """
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    # ==========================================================================
    # Low-level transport helpers
    # ==========================================================================

    async def _send(
        self,
        method: str,
        path: str,
        timeout: float,
        **kwargs: Any,
    ) -> httpx.Response:
        """Send one HTTP request and translate httpx failures into client errors.

        Args:
            method: HTTP method.
            path: API path appended to ``base_url``.
            timeout: Request timeout in seconds.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``
                (``headers``, ``json``, ``params``, ``files``, ...).

        Returns:
            httpx.Response: The raw response. The status code is NOT checked here.

        Raises:
            RagflowConnectionError: Timeout, connection failure, or another
                transport-level error.
            RagflowError: Any other httpx error.
        """
        url = f"{self.base_url}{path}"
        try:
            async with httpx.AsyncClient() as client:
                # Non-streaming request: the body is fully read before the
                # client closes, so returning the response here is safe.
                return await client.request(
                    method=method, url=url, timeout=timeout, **kwargs
                )
        except httpx.TimeoutException as e:
            raise RagflowConnectionError(
                f"RAGFlow 请求超时 ({timeout}s): {path}"
            ) from e
        except httpx.ConnectError as e:
            raise RagflowConnectionError(
                f"RAGFlow 连接失败: {self.base_url}"
            ) from e
        except httpx.TransportError as e:
            # Other network-level failures (dropped connection, protocol error, ...)
            raise RagflowConnectionError(f"RAGFlow 网络异常: {e}") from e
        except httpx.HTTPError as e:
            raise RagflowError(f"RAGFlow 请求异常: {str(e)}") from e

    @staticmethod
    def _parse_response(
        response: httpx.Response,
        default_error: str = "未知错误",
    ) -> Dict[str, Any]:
        """Check the HTTP status and unwrap RAGFlow's JSON envelope.

        RAGFlow responses are expected to look like
        ``{"code": 0, "data": ..., "message": ...}``; a non-zero ``code`` is an
        application-level error even when the HTTP status is 2xx.

        Args:
            response: Raw httpx response.
            default_error: Message used when a non-zero ``code`` carries no
                ``message``.

        Returns:
            Dict: The decoded envelope; its ``code`` is guaranteed to be 0.

        Raises:
            RagflowAuthError: HTTP 401.
            RagflowApiError: HTTP status >= 400, or envelope ``code`` != 0.
            RagflowError: The body is not a JSON object.
        """
        if response.status_code == 401:
            raise RagflowAuthError("RAGFlow API Key 无效或已过期")
        if response.status_code >= 400:
            try:
                err_msg = response.json().get("message", response.text)
            except (ValueError, AttributeError):
                # Body is not JSON, or not a JSON object: fall back to raw text.
                err_msg = response.text
            raise RagflowApiError(
                code=response.status_code,
                message=f"RAGFlow API 错误 ({response.status_code}): {err_msg}",
            )

        try:
            result = response.json()
        except ValueError as e:
            raise RagflowError(f"RAGFlow 响应不是有效的 JSON: {e}") from e
        if not isinstance(result, dict):
            raise RagflowError(f"RAGFlow 响应格式异常: {type(result).__name__}")

        if result.get("code") != 0:
            raise RagflowApiError(
                code=result.get("code", -1),
                message=result.get("message") or default_error,
            )
        return result

    @staticmethod
    def _extract_list(
        result: Dict[str, Any], *keys: str
    ) -> Tuple[List[Any], int]:
        """Extract a record list and a total count from a response envelope.

        Two payload shapes are accepted because list-style responses are not
        guaranteed to wrap their records the same way (NOTE(review): confirm
        the exact shape against the deployed RAGFlow version):

          * ``{"data": {"<key>": [...], "total": N}}`` — the first of ``keys``
            whose value is a list is used; ``total`` defaults to 0.
          * ``{"data": [...]}`` — the list itself is used.

        Args:
            result: Decoded response envelope.
            *keys: Candidate keys for the record list inside an object payload.

        Returns:
            Tuple[List, int]: ``(records, total)``; a missing or null payload
            yields ``([], 0)``.
        """
        data = result.get("data")
        if isinstance(data, list):
            # Without a server-reported total, this is only the current page size.
            return data, result.get("total", len(data))
        if isinstance(data, dict):
            total = data.get("total", 0)
            for key in keys:
                items = data.get(key)
                if isinstance(items, list):
                    return items, total
            return [], total
        return [], 0

    async def _request(
        self,
        method: str,
        path: str,
        json_data: Optional[Dict] = None,
        params: Optional[Dict] = None,
        timeout: Optional[float] = None,
    ) -> Dict[str, Any]:
        """Send a JSON request and return the validated response envelope.

        Args:
            method: HTTP method (GET/POST/PUT/DELETE).
            path: API path (e.g. /api/v1/retrieval).
            json_data: JSON request body.
            params: Query parameters.
            timeout: Overrides the default timeout when not None.

        Returns:
            Dict: The decoded JSON envelope (``code`` == 0).

        Raises:
            RagflowAuthError: Authentication failed (HTTP 401).
            RagflowApiError: HTTP error status or non-zero envelope ``code``.
            RagflowConnectionError: Network failure or timeout.
            RagflowError: Any other request or decoding failure.
        """
        req_timeout = timeout if timeout is not None else self.timeout
        response = await self._send(
            method,
            path,
            req_timeout,
            headers=self._headers(),
            json=json_data,
            params=params,
        )
        return self._parse_response(response)

    # ==========================================================================
    # Connection test
    # ==========================================================================

    async def test_connection(self) -> Dict[str, Any]:
        """Test the RAGFlow API connection.

        Lists datasets with page_size=1 to check that the API key is accepted.
        Never raises RagflowError subclasses; failures are reported in the result.

        Returns:
            Dict: ``{success: bool, message: str}``
        """
        try:
            result = await self.list_datasets(page=1, page_size=1)
            return {
                "success": True,
                "message": f"连接成功,共 {result.get('total', 0)} 个知识库",
            }
        except RagflowAuthError:
            return {"success": False, "message": "API Key 无效或已过期"}
        except RagflowConnectionError as e:
            return {"success": False, "message": f"连接失败: {e.message}"}
        except RagflowError as e:
            return {"success": False, "message": e.message}

    # ==========================================================================
    # Knowledge retrieval (core endpoint)
    # ==========================================================================

    async def retrieval(
        self,
        question: str,
        dataset_ids: Optional[List[str]] = None,
        document_ids: Optional[List[str]] = None,
        similarity_threshold: float = 0.2,
        vector_similarity_weight: float = 0.3,
        top_k: int = 1024,
        keyword: bool = False,
        highlight: bool = False,
    ) -> RetrievalResult:
        """Retrieve the document chunks most relevant to a question.

        Args:
            question: The user's query.
            dataset_ids: Dataset IDs to search (use this or ``document_ids``).
            document_ids: Document IDs to search.
            similarity_threshold: Minimum similarity (0-1), default 0.2.
            vector_similarity_weight: Weight of vector similarity (0-1),
                default 0.3.
            top_k: Number of chunks taking part in the computation,
                default 1024.
            keyword: Whether to enable keyword matching, default False.
            highlight: Whether to highlight matched terms, default False.

        Returns:
            RetrievalResult: Chunks, per-document aggregates and total count.

        Raises:
            RagflowError: Retrieval failed.
        """
        body: Dict[str, Any] = {
            "question": question,
            "similarity_threshold": similarity_threshold,
            "vector_similarity_weight": vector_similarity_weight,
            "top_k": top_k,
            "keyword": keyword,
            "highlight": highlight,
        }
        # Only send the filters that were actually provided.
        if dataset_ids:
            body["dataset_ids"] = dataset_ids
        if document_ids:
            body["document_ids"] = document_ids

        result = await self._request("POST", "/api/v1/retrieval", json_data=body)
        # A missing or null "data" / list field is treated as empty.
        data = result.get("data") or {}

        chunks = [
            RetrievalChunk.model_validate(chunk)
            for chunk in data.get("chunks") or []
        ]
        doc_aggs = [
            DocAggregate.model_validate(agg)
            for agg in data.get("doc_aggs") or []
        ]
        return RetrievalResult(
            chunks=chunks,
            doc_aggs=doc_aggs,
            total=data.get("total", 0),
        )

    # ==========================================================================
    # Dataset (knowledge base) management
    # ==========================================================================

    async def list_datasets(
        self,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List datasets (knowledge bases).

        Args:
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict: ``{items: List[DatasetInfo], total: int}``
        """
        result = await self._request(
            "GET",
            "/api/v1/datasets",
            params={"page": page, "page_size": page_size},
        )
        records, total = self._extract_list(result, "datasets")
        items = [DatasetInfo.model_validate(ds) for ds in records]
        return {"items": items, "total": total}

    async def create_dataset(
        self,
        name: str,
        embedding_model: str = "BAAI/bge-m3@BAAI",
        chunk_method: str = "naive",
        permission: str = "me",
    ) -> DatasetInfo:
        """Create a dataset (knowledge base).

        Args:
            name: Dataset name.
            embedding_model: Embedding model identifier.
            chunk_method: Chunking method (naive/qa/book/laws, ...).
            permission: Permission scope (me/team).

        Returns:
            DatasetInfo: The created dataset.
        """
        body = {
            "name": name,
            "embedding_model": embedding_model,
            "chunk_method": chunk_method,
            "permission": permission,
        }
        result = await self._request("POST", "/api/v1/datasets", json_data=body)
        return DatasetInfo.model_validate(result.get("data") or {})

    async def delete_dataset(self, dataset_ids: List[str]) -> bool:
        """Delete datasets.

        Args:
            dataset_ids: IDs of the datasets to delete.

        Returns:
            bool: True on success (failures raise instead).
        """
        await self._request(
            "DELETE",
            "/api/v1/datasets",
            json_data={"ids": dataset_ids},
        )
        return True

    # ==========================================================================
    # Document management
    # ==========================================================================

    async def list_documents(
        self,
        dataset_id: str,
        page: int = 1,
        page_size: int = DEFAULT_PAGE_SIZE,
    ) -> Dict[str, Any]:
        """List the documents of a dataset.

        Args:
            dataset_id: Dataset ID.
            page: Page number.
            page_size: Items per page.

        Returns:
            Dict: ``{items: List[DocumentInfo], total: int}``
        """
        result = await self._request(
            "GET",
            f"/api/v1/datasets/{dataset_id}/documents",
            params={"page": page, "page_size": page_size},
        )
        # The record list may be keyed "documents" or "docs"; accept either.
        records, total = self._extract_list(result, "documents", "docs")
        items = [DocumentInfo.model_validate(doc) for doc in records]
        return {"items": items, "total": total}

    async def upload_document(
        self,
        dataset_id: str,
        file_path: str,
        file_name: Optional[str] = None,
        timeout: float = DEFAULT_UPLOAD_TIMEOUT,
    ) -> DocumentInfo:
        """Upload a local file to a dataset as a multipart request.

        Args:
            dataset_id: Dataset ID.
            file_path: Local file path.
            file_name: File name to send (defaults to the basename of
                ``file_path``).
            timeout: Request timeout in seconds (default 60).

        Returns:
            DocumentInfo: The first document reported by the server, or a
            ``DocumentInfo`` carrying only the file name when none is reported.

        Raises:
            RagflowError: The file does not exist or cannot be read, or the
                response cannot be parsed.
            RagflowAuthError: Authentication failed (HTTP 401).
            RagflowApiError: HTTP error status or non-zero envelope ``code``.
            RagflowConnectionError: Network failure or timeout.
        """
        if not os.path.exists(file_path):
            raise RagflowError(f"文件不存在: {file_path}")

        fname = file_name or os.path.basename(file_path)
        path = f"/api/v1/datasets/{dataset_id}/documents"
        # Auth header only: httpx must generate the multipart Content-Type
        # (including its boundary), so the JSON header from _headers() is not used.
        headers = {"Authorization": f"Bearer {self.api_key}"}

        try:
            with open(file_path, "rb") as f:
                response = await self._send(
                    "POST",
                    path,
                    timeout,
                    headers=headers,
                    files={"file": (fname, f)},
                )
        except OSError as e:
            raise RagflowError(f"文档上传失败: {str(e)}") from e

        result = self._parse_response(response, default_error="上传失败")
        # The payload may be a bare list of documents or an object with a
        # "documents" list; _extract_list accepts both.
        docs, _ = self._extract_list(result, "documents")
        if not docs:
            return DocumentInfo(name=fname)
        try:
            return DocumentInfo.model_validate(docs[0])
        except ValueError as e:
            # Model validation errors are ValueError subclasses; keep the
            # RagflowError contract for callers of this method.
            raise RagflowError(f"文档上传失败: {str(e)}") from e

    async def delete_documents(
        self,
        dataset_id: str,
        document_ids: List[str],
    ) -> bool:
        """Delete documents from a dataset.

        Args:
            dataset_id: Dataset ID.
            document_ids: IDs of the documents to delete.

        Returns:
            bool: True on success (failures raise instead).
        """
        await self._request(
            "DELETE",
            f"/api/v1/datasets/{dataset_id}/documents",
            json_data={"ids": document_ids},
        )
        return True