ragas_eval_with_rag_api.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
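"""
RAGAS evaluation script with integrated RAG retrieval - FastAPI version.

API service built on top of ragas_eval_with_rag.py.

Features:
- Upload a JSON file (each entry only needs `question` and `ground_truth`)
- Automatically obtain `contexts` and `answer` through the RAG pipeline
- Evaluate with all six configured RAGAS metrics (faithfulness,
  answer_correctness, answer_relevancy, context_precision, context_recall,
  NoiseSensitivity)
- Return detailed per-row results and summary statistics
"""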
  10. import os
  11. import sys
  12. import json
  13. import pandas as pd
  14. import asyncio
  15. import math
  16. from datasets import Dataset
  17. from typing import List, Dict, Any
  18. # FastAPI 相关
  19. from fastapi import FastAPI, UploadFile, File, HTTPException, Form
  20. from fastapi.middleware.cors import CORSMiddleware
  21. from fastapi.responses import FileResponse, JSONResponse
  22. from fastapi.staticfiles import StaticFiles
  23. import uvicorn
  24. # 添加项目根目录到 Python 路径
  25. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  26. sys.path.insert(0, project_root)
  27. # RAGAS 评估相关
  28. from ragas import evaluate
  29. from ragas.metrics import (
  30. faithfulness,
  31. answer_correctness,
  32. answer_relevancy,
  33. context_precision,
  34. context_recall,
  35. NoiseSensitivity
  36. )
  37. from langchain_openai import ChatOpenAI, OpenAIEmbeddings
  38. # RAG 相关导入
  39. from rag.chat_message import ChatRetrieverRag
  40. from rag.llm import VllmApi
  41. # 配置 VLLM 和 Embedding 服务
  42. VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
  43. VLLM_LLM_KEY = "vllm-dummy-key"
  44. VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"
  45. # 初始化 RAGAS 评估使用的模型
  46. vllm_generator = ChatOpenAI(
  47. model="Qwen3-Coder-30B-loft",
  48. base_url=VLLM_LLM_BASE,
  49. api_key=VLLM_LLM_KEY,
  50. temperature=0.1,
  51. max_tokens=51200,
  52. timeout=3000,
  53. request_timeout=3000,
  54. max_retries=5,
  55. )
  56. vllm_embeddings = OpenAIEmbeddings(
  57. model="",
  58. base_url=VLLM_EMBEDDING_BASE,
  59. api_key=VLLM_LLM_KEY,
  60. )
  61. # 定义评估指标
  62. metrics_to_run = [
  63. faithfulness,
  64. answer_correctness,
  65. answer_relevancy,
  66. context_precision,
  67. context_recall,
  68. NoiseSensitivity(llm=vllm_generator)
  69. ]
  70. class RAGEvaluator:
  71. """集成 RAG 功能的评估器"""
  72. def __init__(
  73. self,
  74. knowledge_ids: List[str],
  75. embedding_id: str = "e5",
  76. temperature: float = 0.6,
  77. top_p: float = 0.7,
  78. max_tokens: int = 4096,
  79. model: str = "Qwen3-Coder-30B-loft",
  80. enable_think: bool = False,
  81. slice_count: int = 5
  82. ):
  83. """初始化评估器
  84. Args:
  85. knowledge_ids: 知识库 ID 列表
  86. embedding_id: 嵌入模型 ID
  87. temperature: LLM 温度参数
  88. top_p: LLM top_p 参数
  89. max_tokens: 最大生成 token 数
  90. model: LLM 模型名称
  91. enable_think: 是否启用思考模式
  92. slice_count: 检索切片数量
  93. """
  94. self.knowledge_ids = knowledge_ids
  95. self.embedding_id = embedding_id
  96. self.temperature = temperature
  97. self.top_p = top_p
  98. self.max_tokens = max_tokens
  99. self.model = model
  100. self.enable_think = enable_think
  101. self.slice_count = slice_count
  102. self.rag_config = {
  103. "knowledgeIds": knowledge_ids,
  104. "embeddingId": embedding_id,
  105. "sliceCount": slice_count,
  106. "knowledgeInfo": json.dumps({
  107. "recall_method": "mixed",
  108. "rerank_status": True,
  109. "rerank_model_name": "bce_rerank_model"
  110. }),
  111. "temperature": temperature,
  112. "topP": top_p,
  113. "maxToken": max_tokens,
  114. "enable_think": enable_think,
  115. "prompt": """你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{知识}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:"进入登录页面(图1.1)"需要从文字中删除图序,回复效果为:"进入登录页面";"如图所示1.1",回复效果为:"如图所示";\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:"我想","我认为","think"等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{用户}"""
  116. }
  117. async def get_rag_response(self, question: str) -> Dict[str, Any]:
  118. """通过 RAG 系统获取答案和上下文"""
  119. chat_json = self.rag_config.copy()
  120. chat_json["query"] = question
  121. rag_retriever = ChatRetrieverRag(chat_json=chat_json)
  122. retriever_result_list, search_doc_id_to_knowledge_id_dict = rag_retriever.retriever_result(chat_json)
  123. chunk_content, knowledge_info_dict = rag_retriever.parse_retriever_list(
  124. retriever_result_list,
  125. search_doc_id_to_knowledge_id_dict
  126. )
  127. contexts = [item["content"] for item in retriever_result_list]
  128. answer = ""
  129. async for event in rag_retriever.generate_rag_response(chat_json, chunk_content):
  130. if event.get("event") == "add":
  131. answer = event.get("data", "")
  132. elif event.get("event") == "finish":
  133. break
  134. return {
  135. "answer": answer,
  136. "contexts": contexts,
  137. "retriever_count": len(retriever_result_list)
  138. }
  139. async def process_qa_data(self, qa_data: List[Dict[str, str]]) -> List[Dict[str, Any]]:
  140. """处理问答数据,通过 RAG 获取 contexts 和 answer"""
  141. full_data = []
  142. for idx, item in enumerate(qa_data):
  143. question = item.get("question", "")
  144. ground_truth = item.get("ground_truth", "")
  145. print(f"处理第 {idx + 1}/{len(qa_data)} 个问题: {question[:50]}...")
  146. try:
  147. rag_result = await self.get_rag_response(question)
  148. full_item = {
  149. "question": question,
  150. "contexts": rag_result["contexts"],
  151. "answer": rag_result["answer"],
  152. "ground_truth": ground_truth,
  153. }
  154. full_data.append(full_item)
  155. print(f" ✓ 检索到 {rag_result['retriever_count']} 个上下文")
  156. except Exception as e:
  157. print(f" ✗ 处理失败: {e}")
  158. full_data.append({
  159. "question": question,
  160. "contexts": [],
  161. "answer": "",
  162. "ground_truth": ground_truth,
  163. })
  164. return full_data
  165. def _clean_float_values(obj):
  166. """清理非法浮点数值(NaN, Infinity)为 None"""
  167. if isinstance(obj, float):
  168. if math.isnan(obj) or math.isinf(obj):
  169. return None
  170. return obj
  171. elif isinstance(obj, dict):
  172. return {k: _clean_float_values(v) for k, v in obj.items()}
  173. elif isinstance(obj, list):
  174. return [_clean_float_values(item) for item in obj]
  175. return obj
  176. def run_evaluation(dataset: Dataset) -> Dict[str, Any]:
  177. """执行 RAGAS 评估"""
  178. print("\n" + "="*60)
  179. print("开始 RAGAS 评估...")
  180. print("="*60)
  181. final_metrics = list(metrics_to_run)
  182. if 'ground_truth' not in dataset.column_names:
  183. print("警告: 缺少 'ground_truth',将跳过 'context_recall' 指标")
  184. final_metrics.remove(context_recall)
  185. print(f"使用指标: {[m.name for m in final_metrics]}")
  186. print(f"数据集大小: {len(dataset)} 条\n")
  187. try:
  188. result = evaluate(
  189. dataset=dataset,
  190. metrics=final_metrics,
  191. llm=vllm_generator,
  192. embeddings=vllm_embeddings,
  193. raise_exceptions=False,
  194. show_progress=True,
  195. )
  196. except Exception as e:
  197. print(f"\n评估过程中出现错误: {e}")
  198. raise
  199. result_df = result.to_pandas()
  200. # 替换 DataFrame 中的 NaN 和 Infinity 为 None
  201. result_df = result_df.replace([float('inf'), float('-inf')], None)
  202. result_df = result_df.where(pd.notna(result_df), None)
  203. # 计算平均指标(仅计算真正的评估指标)
  204. metric_names = [m.name for m in final_metrics]
  205. summary = {}
  206. for metric_name in metric_names:
  207. if metric_name in result_df.columns:
  208. try:
  209. mean_val = result_df[metric_name].mean(skipna=True)
  210. if pd.notna(mean_val) and math.isfinite(mean_val):
  211. summary[metric_name] = float(mean_val)
  212. else:
  213. summary[metric_name] = None
  214. except Exception:
  215. summary[metric_name] = None
  216. print("\n" + "="*60)
  217. print("评估完成!")
  218. print("="*60)
  219. # 转换为字典并清理
  220. rows = result_df.to_dict(orient="records")
  221. return {
  222. "rows": _clean_float_values(rows), # 清理 NaN 和 Infinity, rows是评估结果的详细信息
  223. "summary": _clean_float_values(summary), # 清理 NaN 和 Infinity, summary是评估结果的汇总信息
  224. "metrics": [m.name for m in final_metrics], # final_metrics是评估指标
  225. "count": len(result_df), # 评估指标数
  226. }
# FastAPI application
app = FastAPI(
    title="RAGAS RAG 评估 API",
    description="集成 RAG 检索功能的 RAGAS 评估服务",
    version="1.0.0"
)
# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Serve the front-end static files under /static when a "frontend" directory
# exists next to this file.
frontend_dir = os.path.join(os.path.dirname(__file__), "frontend")
if os.path.isdir(frontend_dir):
    app.mount("/static", StaticFiles(directory=frontend_dir), name="static")
  244. @app.get("/")
  245. def index():
  246. """首页"""
  247. # 优先使用 rag_eval.html
  248. rag_eval_path = os.path.join(frontend_dir, "rag_eval.html")
  249. if os.path.isfile(rag_eval_path):
  250. return FileResponse(rag_eval_path)
  251. # 备用 index.html
  252. index_path = os.path.join(frontend_dir, "index.html")
  253. if os.path.isfile(index_path):
  254. return FileResponse(index_path)
  255. return JSONResponse({
  256. "message": "RAGAS RAG 评估 API",
  257. "docs": "/docs",
  258. "version": "1.0.0"
  259. })
  260. @app.post("/api/evaluate")
  261. async def api_evaluate(
  262. file: UploadFile = File(...),
  263. knowledge_ids: str = Form(...),
  264. embedding_id: str = Form("e5"),
  265. temperature: float = Form(0.6),
  266. top_p: float = Form(0.7),
  267. max_tokens: int = Form(4096),
  268. model: str = Form("Qwen3-Coder-30B-loft"),
  269. enable_think: bool = Form(False),
  270. slice_count: int = Form(5)
  271. ):
  272. """
  273. 评估接口
  274. 参数:
  275. - file: JSON 文件(包含 question 和 ground_truth)
  276. - knowledge_ids: 知识库 ID,多个用逗号分隔
  277. - embedding_id: 嵌入模型 ID(默认 e5)
  278. - temperature: LLM 温度参数(默认 0.6)
  279. - top_p: LLM top_p 参数(默认 0.7)
  280. - max_tokens: 最大生成 token 数(默认 4096)
  281. - model: LLM 模型名称(默认 Qwen3-Coder-30B-loft)
  282. - enable_think: 是否启用思考模式(默认 False)
  283. - slice_count: 检索切片数量(默认 5)
  284. """
  285. if not file.filename.lower().endswith(".json"):
  286. raise HTTPException(status_code=400, detail="仅支持 .json 文件")
  287. try:
  288. # 解析知识库 ID
  289. knowledge_id_list = [kid.strip() for kid in knowledge_ids.split(",") if kid.strip()]
  290. if not knowledge_id_list:
  291. raise ValueError("knowledge_ids 不能为空")
  292. # 读取文件
  293. content = await file.read()
  294. qa_data = json.loads(content.decode("utf-8"))
  295. if not isinstance(qa_data, list):
  296. raise ValueError("JSON 必须是数组格式")
  297. if not qa_data:
  298. raise ValueError("数据不能为空")
  299. # 初始化评估器
  300. print(f"\n配置: 模型={model}, 温度={temperature}, top_p={top_p}, 切片数={slice_count}")
  301. evaluator = RAGEvaluator(
  302. knowledge_ids=knowledge_id_list,
  303. embedding_id=embedding_id,
  304. temperature=temperature,
  305. top_p=top_p,
  306. max_tokens=max_tokens,
  307. model=model,
  308. enable_think=enable_think,
  309. slice_count=slice_count
  310. )
  311. # 处理数据
  312. print(f"\n开始处理 {len(qa_data)} 条数据...")
  313. full_data = await evaluator.process_qa_data(qa_data)
  314. # 创建 Dataset
  315. dataset = Dataset.from_dict({
  316. "question": [item["question"] for item in full_data],
  317. "contexts": [item["contexts"] for item in full_data],
  318. "answer": [item["answer"] for item in full_data],
  319. "ground_truth": [item["ground_truth"] for item in full_data],
  320. })
  321. # 执行评估
  322. result = run_evaluation(dataset)
  323. return JSONResponse(result)
  324. except json.JSONDecodeError as e:
  325. raise HTTPException(status_code=400, detail=f"JSON 解析失败: {e}")
  326. except Exception as e:
  327. import traceback
  328. error_detail = f"评估失败: {e}\n{traceback.format_exc()}"
  329. print(error_detail)
  330. raise HTTPException(status_code=500, detail=str(e))
  331. if __name__ == "__main__":
  332. port = int(os.environ.get("PORT", "8001"))
  333. print(f"启动 RAGAS RAG 评估 API 服务...")
  334. print(f"访问地址: http://0.0.0.0:{port}")
  335. print(f"API 文档: http://0.0.0.0:{port}/docs")
  336. uvicorn.run(
  337. "ragas_eval_with_rag_api:app",
  338. host="0.0.0.0",
  339. port=port,
  340. reload=False
  341. )