ragas_eval.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. import os
  2. import json
  3. import pandas as pd
  4. from datasets import Dataset
  5. from typing import List, Dict, Any
  6. # FastAPI 相关
  7. from fastapi import FastAPI, UploadFile, File, HTTPException
  8. from fastapi.middleware.cors import CORSMiddleware
  9. from fastapi.responses import FileResponse, JSONResponse
  10. from fastapi.staticfiles import StaticFiles
  11. import uvicorn
  12. os.environ["RUN_SERVER"] = "True"
  13. from ragas import evaluate
  14. from ragas.metrics import (
  15. faithfulness,
  16. answer_correctness,
  17. answer_relevancy,
  18. context_precision,
  19. context_recall,
  20. )
  21. from langchain_openai import ChatOpenAI, OpenAIEmbeddings
  22. VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
  23. VLLM_LLM_KEY = "vllm-dummy-key"
  24. VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"
  25. vllm_generator = ChatOpenAI(
  26. model="Qwen3-Coder-30B-loft",
  27. base_url=VLLM_LLM_BASE,
  28. api_key=VLLM_LLM_KEY,
  29. temperature=0,
  30. max_tokens=512,
  31. )
  32. vllm_embeddings = OpenAIEmbeddings(
  33. model="",
  34. base_url=VLLM_EMBEDDING_BASE,
  35. api_key=VLLM_LLM_KEY,
  36. )
  37. # --- 2. 加载和准备数据集 ---
  38. def load_data_from_json(json_path):
  39. """
  40. 从 JSON 文件加载数据并转换为 datasets.Dataset 格式。
  41. 假设的 JSON 结构:
  42. [
  43. {
  44. "question": "...",
  45. "answer": "...",
  46. "contexts": ["...", "..."],
  47. "ground_truth": "..."
  48. },
  49. ...
  50. ]
  51. """
  52. try:
  53. with open(json_path, 'r', encoding='utf-8') as f:
  54. data_list = json.load(f)
  55. df = pd.DataFrame(data_list)
  56. required_cols = {'question', 'answer', 'contexts'}
  57. if not required_cols.issubset(df.columns):
  58. raise ValueError(f"JSON 文件必须包含 'question', 'answer', 和 'contexts' 键")
  59. if 'ground_truth' not in df.columns:
  60. print("警告: 未找到 'ground_truth' 列。'context_recall' 指标将无法计算。")
  61. dataset = Dataset.from_pandas(df)
  62. return dataset
  63. except Exception as e:
  64. print(f"加载数据时出错: {e}")
  65. return None
  66. def load_data_from_list(data_list: List[Dict[str, Any]]):
  67. """
  68. 从字典列表加载数据并转换为 datasets.Dataset。
  69. 该函数用于 FastAPI 文件上传的场景。
  70. """
  71. try:
  72. df = pd.DataFrame(data_list)
  73. required_cols = {'question', 'answer', 'contexts'}
  74. if not required_cols.issubset(df.columns):
  75. raise ValueError("JSON 必须包含 'question', 'answer', 'contexts' 键")
  76. if 'ground_truth' not in df.columns:
  77. print("警告: 未找到 'ground_truth' 列。'context_recall' 指标将无法计算。")
  78. dataset = Dataset.from_pandas(df)
  79. return dataset
  80. except Exception as e:
  81. raise ValueError(f"加载上传数据时出错: {e}")
  82. # --- 3. 定义指标 ---
  83. metrics_to_run = [
  84. faithfulness,
  85. answer_correctness,
  86. answer_relevancy,
  87. context_precision,
  88. context_recall,
  89. ]
  90. # --- 4. 执行评估 ---
  91. def run_evaluation(dataset_path):
  92. print(f"正在从 {dataset_path} 加载数据...")
  93. dataset = load_data_from_json(dataset_path)
  94. if dataset is None:
  95. print("数据加载失败,退出评估。")
  96. return
  97. print(f"成功加载 {len(dataset)} 条评估数据。")
  98. final_metrics = list(metrics_to_run)
  99. if 'ground_truth' not in dataset.column_names:
  100. print("由于缺少 'ground_truth',将跳过 'context_recall' 指标。")
  101. final_metrics.remove(context_recall)
  102. if not final_metrics:
  103. print("没有可执行的指标。退出。")
  104. return
  105. print("开始 RAGAS 评估... (这可能需要一些时间)")
  106. # 核心:调用 evaluate
  107. result = evaluate(
  108. dataset=dataset,
  109. metrics=final_metrics,
  110. llm=vllm_generator,
  111. embeddings=vllm_embeddings
  112. )
  113. print("评估完成!")
  114. result_df = result.to_pandas()
  115. print(json.dumps(result_df.to_dict(orient="records"), indent=4, ensure_ascii=False))
  116. def summarize_metrics(result_df: pd.DataFrame) -> Dict[str, float]:
  117. """
  118. 汇总各指标的平均值,便于前端展示。
  119. """
  120. summary = {}
  121. for col in [
  122. 'faithfulness',
  123. 'answer_correctness',
  124. 'answer_relevancy',
  125. 'context_precision',
  126. 'context_recall',
  127. ]:
  128. if col in result_df.columns:
  129. try:
  130. summary[col] = float(result_df[col].mean())
  131. except Exception:
  132. pass
  133. return summary
  134. def evaluate_dataset(dataset: Dataset, final_metrics: List = None) -> Dict[str, Any]:
  135. """
  136. 针对传入的 Dataset 执行评估,返回行结果与汇总。
  137. """
  138. if final_metrics is None:
  139. final_metrics = list(metrics_to_run)
  140. # 缺 ground_truth 时移除 context_recall
  141. if 'ground_truth' not in dataset.column_names and context_recall in final_metrics:
  142. final_metrics.remove(context_recall)
  143. result = evaluate(
  144. dataset=dataset,
  145. metrics=final_metrics,
  146. llm=vllm_generator,
  147. embeddings=vllm_embeddings,
  148. )
  149. result_df = result.to_pandas()
  150. print(json.dumps(result_df.to_dict(orient="records"), indent=4, ensure_ascii=False))
  151. return {
  152. "rows": result_df.to_dict(orient="records"),
  153. "summary": summarize_metrics(result_df),
  154. "metrics": [m.name for m in final_metrics],
  155. "count": len(result_df),
  156. }
  157. # --- FastAPI 应用 ---
  158. app = FastAPI()
  159. # 允许本地前端访问
  160. app.add_middleware(
  161. CORSMiddleware,
  162. allow_origins=["*"],
  163. allow_credentials=True,
  164. allow_methods=["*"],
  165. allow_headers=["*"],
  166. )
  167. # 挂载前端静态目录
  168. frontend_dir = os.path.join(os.path.dirname(__file__), "frontend")
  169. if os.path.isdir(frontend_dir):
  170. app.mount("/static", StaticFiles(directory=frontend_dir), name="static")
  171. @app.get("/")
  172. def index():
  173. # 如果存在前端页面则返回
  174. index_path = os.path.join(frontend_dir, "index.html")
  175. if os.path.isfile(index_path):
  176. return FileResponse(index_path)
  177. return JSONResponse({"message": "前端未创建,请访问 /api/docs 使用接口。"})
  178. @app.post("/api/evaluate")
  179. async def api_evaluate(file: UploadFile = File(...)):
  180. """
  181. 上传一个 JSON 文件(数组,每项包含 question/answer/contexts/ground_truth 可选),执行评估。
  182. """
  183. if not file.filename.lower().endswith(".json"):
  184. raise HTTPException(status_code=400, detail="仅支持上传 .json 文件")
  185. try:
  186. content = await file.read()
  187. data_list = json.loads(content.decode("utf-8"))
  188. if not isinstance(data_list, list):
  189. raise ValueError("JSON 顶层必须为数组")
  190. dataset = load_data_from_list(data_list)
  191. # ground_truth 缺失时提示前端
  192. final_metrics = list(metrics_to_run)
  193. if 'ground_truth' not in dataset.column_names and context_recall in final_metrics:
  194. final_metrics.remove(context_recall)
  195. skipped = ["context_recall"]
  196. else:
  197. skipped = []
  198. result_payload = evaluate_dataset(dataset, final_metrics=final_metrics)
  199. result_payload["skipped_metrics"] = skipped
  200. # print(result_payload)
  201. return JSONResponse(result_payload)
  202. except Exception as e:
  203. raise HTTPException(status_code=400, detail=f"评估失败: {e}")
  204. # --- 主程序入口 ---
  205. if __name__ == "__main__":
  206. # 命令行模式:仍支持直接评估示例文件
  207. input_json_file = "dataset.json"
  208. # 如果设置了环境变量 RUN_SERVER=true,则启动服务
  209. run_server = os.environ.get("RUN_SERVER", "").lower() in {"1", "true", "yes"}
  210. if run_server:
  211. # 可通过环境变量 RELOAD 控制是否开启热重载,默认关闭以避免重载中断大型库导入
  212. reload_enabled = os.environ.get("RELOAD", "").lower() in {"1", "true", "yes"}
  213. uvicorn.run("ragas_eval:app", host="0.0.0.0", port=8000, reload=reload_enabled)
  214. else:
  215. if not os.path.exists(input_json_file):
  216. print(f"错误: 找不到 {input_json_file}")
  217. print("请创建一个示例 JSON 文件,结构如下:")
  218. print(
  219. """
  220. [
  221. {
  222. "question": "什么是 RAGAS?",
  223. "answer": "RAGAS 是一个评估 RAG 管道的框架。",
  224. "contexts": [
  225. "RAGAS (Retrieval-Augmented Generation Assessment) 是一个用于评估 RAG 管道的框架。"
  226. ],
  227. "ground_truth": "RAGAS 是一个专为评估 RAG 管道设计的框架,它关注检索和生成的质量。"
  228. }
  229. ]
  230. """
  231. )
  232. else:
  233. run_evaluation(input_json_file)