ragas_eval_with_rag.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. """
  2. RAGAS评估脚本 - 集成RAG检索功能
  3. 输入:只需要 question 和 ground_truth 的 JSON 文件
  4. 功能:通过 RAG 系统自动获取 contexts 和 answer,然后进行评估
  5. 说明:
  6. - 使用全部 5 个 RAGAS 指标进行评估
  7. - 通过增加超时时间和重试机制来减少 NaN 值
  8. - NaN 值通常由超时或网络问题导致,而非指标本身问题
  9. """
  10. import os
  11. import sys
  12. import json
  13. import pandas as pd
  14. import asyncio
  15. from datasets import Dataset
  16. from typing import List, Dict, Any
  17. # 添加项目根目录到 Python 路径
  18. project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
  19. sys.path.insert(0, project_root)
  20. # RAGAS 评估相关
  21. from ragas import evaluate
  22. from ragas.metrics import (
  23. faithfulness,
  24. answer_correctness,
  25. answer_relevancy,
  26. context_precision,
  27. context_recall,
  28. )
  29. from langchain_openai import ChatOpenAI, OpenAIEmbeddings
  30. # RAG 相关导入
  31. from rag.chat_message import ChatRetrieverRag
  32. from rag.llm import VllmApi
  33. # 配置 VLLM 和 Embedding 服务
  34. VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
  35. VLLM_LLM_KEY = "vllm-dummy-key"
  36. VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"
  37. # 初始化 RAGAS 评估使用的模型
  38. vllm_generator = ChatOpenAI(
  39. model="Qwen3-Coder-30B-loft",
  40. base_url=VLLM_LLM_BASE,
  41. api_key=VLLM_LLM_KEY,
  42. temperature=0.1, # 降低温度提高稳定性
  43. max_tokens=51200,
  44. timeout=3000, # 增加超时到 300 秒
  45. request_timeout=3000,
  46. max_retries=5, # 增加重试次数
  47. )
  48. vllm_embeddings = OpenAIEmbeddings(
  49. model="",
  50. base_url=VLLM_EMBEDDING_BASE,
  51. api_key=VLLM_LLM_KEY,
  52. )
  53. # 定义评估指标 - 使用全部 5 个指标
  54. metrics_to_run = [
  55. faithfulness, # 忠实度:答案是否基于上下文
  56. answer_correctness, # 答案正确性:答案与标准答案的相似度
  57. answer_relevancy, # 答案相关性:答案是否回答了问题
  58. context_precision, # 上下文精确度:检索上下文的精确度
  59. context_recall, # 上下文召回:上下文是否包含答案信息
  60. ]
  61. class RAGEvaluator:
  62. """集成 RAG 功能的评估器"""
  63. def __init__(self, knowledge_ids: List[str], embedding_id: str = "e5"):
  64. """
  65. 初始化评估器
  66. Args:
  67. knowledge_ids: 知识库 ID 列表
  68. embedding_id: 嵌入模型 ID (可选值: "e5", "multilingual-e5-large-instruct")
  69. """
  70. self.knowledge_ids = knowledge_ids
  71. self.embedding_id = embedding_id
  72. # RAG 配置参数
  73. self.rag_config = {
  74. "knowledgeIds": knowledge_ids,
  75. "embeddingId": embedding_id,
  76. "sliceCount": 5, # 检索的切片数量
  77. "knowledgeInfo": json.dumps({
  78. "recall_method": "mixed", # 检索模式:embedding/keyword/mixed
  79. "rerank_status": True, # 是否启用重排序
  80. "rerank_model_name": "bce_rerank_model"
  81. }),
  82. "temperature": 0.6,
  83. "topP": 0.7,
  84. "maxToken": 4096,
  85. "enable_think": False,
  86. "prompt":
  87. """
  88. 你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{知识}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:“进入登录页面(图1.1)”需要从文字中删除图序,回复效果为:“进入登录页面”;“如图所示1.1”,回复效果为:“如图所示”;\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:“我想”,“我认为”,“think”等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{用户}
  89. """
  90. # "请根据以下知识内容回答用户的问题。\n\n知识内容:\n{知识}\n\n用户问题:\n{用户}\n\n请给出准确、详细的回答:",
  91. }
  92. async def get_rag_response(self, question: str) -> Dict[str, Any]:
  93. """
  94. 通过 RAG 系统获取问题的答案和上下文
  95. Args:
  96. question: 用户问题
  97. Returns:
  98. 包含 answer 和 contexts 的字典
  99. """
  100. # 构建 RAG 请求参数
  101. chat_json = self.rag_config.copy()
  102. chat_json["query"] = question
  103. # 创建 RAG 检索器
  104. rag_retriever = ChatRetrieverRag(chat_json=chat_json)
  105. # 获取检索结果
  106. retriever_result_list, search_doc_id_to_knowledge_id_dict = rag_retriever.retriever_result(chat_json)
  107. # 解析检索结果,获取上下文内容
  108. chunk_content, knowledge_info_dict = rag_retriever.parse_retriever_list(
  109. retriever_result_list,
  110. search_doc_id_to_knowledge_id_dict
  111. )
  112. # 提取 contexts(每个检索到的切片内容)
  113. contexts = [item["content"] for item in retriever_result_list]
  114. # 生成答案
  115. answer = ""
  116. async for event in rag_retriever.generate_rag_response(chat_json, chunk_content):
  117. if event.get("event") == "add":
  118. answer = event.get("data", "")
  119. elif event.get("event") == "finish":
  120. break
  121. return {
  122. "answer": answer,
  123. "contexts": contexts,
  124. "retriever_count": len(retriever_result_list)
  125. }
  126. async def process_qa_data(self, qa_data: List[Dict[str, str]]) -> List[Dict[str, Any]]:
  127. """
  128. 处理问答数据,通过 RAG 获取 contexts 和 answer
  129. Args:
  130. qa_data: 包含 question 和 ground_truth 的列表
  131. Returns:
  132. 完整的评估数据集(包含 question, contexts, answer, ground_truth)
  133. """
  134. full_data = []
  135. for idx, item in enumerate(qa_data):
  136. question = item.get("question", "")
  137. ground_truth = item.get("ground_truth", "")
  138. print(f"\n处理第 {idx + 1}/{len(qa_data)} 个问题: {question[:50]}...")
  139. try:
  140. # 通过 RAG 获取答案和上下文
  141. rag_result = await self.get_rag_response(question)
  142. full_item = {
  143. "question": question,
  144. "contexts": rag_result["contexts"],
  145. "answer": rag_result["answer"],
  146. "ground_truth": ground_truth,
  147. }
  148. full_data.append(full_item)
  149. print(f" ✓ 检索到 {rag_result['retriever_count']} 个上下文")
  150. print(f" ✓ 生成答案长度: {len(rag_result['answer'])} 字符")
  151. except Exception as e:
  152. print(f" ✗ 处理失败: {e}")
  153. # 失败时使用空值
  154. full_data.append({
  155. "question": question,
  156. "contexts": [],
  157. "answer": "",
  158. "ground_truth": ground_truth,
  159. })
  160. return full_data
  161. def load_qa_json(json_path: str) -> List[Dict[str, str]]:
  162. """
  163. 加载只包含 question 和 ground_truth 的 JSON 文件
  164. Args:
  165. json_path: JSON 文件路径
  166. Returns:
  167. 问答数据列表
  168. """
  169. try:
  170. with open(json_path, 'r', encoding='utf-8') as f:
  171. data_list = json.load(f)
  172. # 验证数据格式
  173. if not isinstance(data_list, list):
  174. raise ValueError("JSON 文件必须是一个数组")
  175. for item in data_list:
  176. if "question" not in item or "ground_truth" not in item:
  177. raise ValueError("每个数据项必须包含 'question' 和 'ground_truth' 字段")
  178. return data_list
  179. except Exception as e:
  180. print(f"加载 JSON 文件失败: {e}")
  181. raise
  182. def save_full_dataset(data: List[Dict[str, Any]], output_path: str):
  183. """
  184. 保存完整的数据集(包含 RAG 获取的 contexts 和 answer)
  185. Args:
  186. data: 完整数据集
  187. output_path: 输出文件路径
  188. """
  189. try:
  190. with open(output_path, 'w', encoding='utf-8') as f:
  191. json.dump(data, f, ensure_ascii=False, indent=2)
  192. print(f"\n完整数据集已保存到: {output_path}")
  193. except Exception as e:
  194. print(f"保存数据集失败: {e}")
  195. def run_evaluation(dataset: Dataset) -> Dict[str, Any]:
  196. """
  197. 执行 RAGAS 评估
  198. Args:
  199. dataset: 评估数据集
  200. Returns:
  201. 评估结果
  202. """
  203. print("\n" + "="*60)
  204. print("开始 RAGAS 评估...")
  205. print("="*60)
  206. # 确定要使用的指标
  207. final_metrics = list(metrics_to_run)
  208. if 'ground_truth' not in dataset.column_names:
  209. print("警告: 缺少 'ground_truth',将跳过 'context_recall' 指标")
  210. final_metrics.remove(context_recall)
  211. # 执行评估
  212. print(f"使用指标: {[m.name for m in final_metrics]}")
  213. print(f"数据集大小: {len(dataset)} 条")
  214. print(f"超时设置: 300 秒")
  215. print(f"重试次数: 3 次")
  216. print()
  217. try:
  218. result = evaluate(
  219. dataset=dataset,
  220. metrics=final_metrics,
  221. llm=vllm_generator,
  222. embeddings=vllm_embeddings,
  223. raise_exceptions=False, # 不因单个样本失败而中断
  224. show_progress=True, # 显示进度
  225. )
  226. except Exception as e:
  227. print(f"\n评估过程中出现错误: {e}")
  228. print("建议:检查 VLLM 服务是否正常运行")
  229. raise
  230. # 转换为 DataFrame
  231. result_df = result.to_pandas()
  232. # 计算平均指标(跳过 NaN 值)并统计 NaN 情况
  233. summary = {}
  234. nan_stats = {}
  235. print("\n" + "="*60)
  236. print("NaN 值统计:")
  237. print("="*60)
  238. for col in result_df.columns:
  239. if col not in ['question', 'contexts', 'answer', 'ground_truth', 'user_input', 'retrieved_contexts', 'response', 'reference']:
  240. try:
  241. # # 统计 NaN 数量
  242. # nan_count = result_df[col].isna().sum()
  243. # total_count = len(result_df)
  244. # nan_stats[col] = {
  245. # 'nan_count': int(nan_count),
  246. # 'total_count': total_count,
  247. # 'nan_ratio': float(nan_count / total_count) if total_count > 0 else 0
  248. # }
  249. # # 打印统计信息
  250. # if nan_count > 0:
  251. # print(f"⚠️ {col}: {nan_count}/{total_count} 个 NaN ({nan_count/total_count*100:.1f}%)")
  252. # else:
  253. # print(f"✅ {col}: 无 NaN 值")
  254. # 使用 nanmean 跳过 NaN 值
  255. mean_val = result_df[col].mean(skipna=True)
  256. if pd.notna(mean_val):
  257. summary[col] = float(mean_val)
  258. else:
  259. summary[col] = None
  260. print(f" 警告: 指标 '{col}' 全部为 NaN,无法计算平均值")
  261. except Exception as e:
  262. print(f" 错误: 无法处理指标 '{col}': {e}")
  263. pass
  264. print("\n" + "="*60)
  265. print("评估完成!")
  266. print("="*60)
  267. return {
  268. "rows": result_df.to_dict(orient="records"),
  269. "summary": summary,
  270. "metrics": [m.name for m in final_metrics],
  271. "count": len(result_df),
  272. }
  273. async def main_async(input_json: str, knowledge_ids: List[str], save_full_data: bool = True):
  274. """
  275. 主函数(异步)
  276. Args:
  277. input_json: 输入 JSON 文件路径(只包含 question 和 ground_truth)
  278. knowledge_ids: 知识库 ID 列表
  279. save_full_data: 是否保存完整数据集
  280. """
  281. print("="*60)
  282. print("RAGAS 评估脚本 - 集成 RAG 检索")
  283. print("="*60)
  284. # 1. 加载问答数据
  285. print(f"\n步骤 1: 加载问答数据从 {input_json}")
  286. qa_data = load_qa_json(input_json)
  287. print(f" ✓ 成功加载 {len(qa_data)} 条问答数据")
  288. # 2. 通过 RAG 获取 contexts 和 answer
  289. print(f"\n步骤 2: 通过 RAG 系统获取上下文和答案")
  290. print(f" 使用知识库: {knowledge_ids}")
  291. evaluator = RAGEvaluator(knowledge_ids=knowledge_ids)
  292. full_data = await evaluator.process_qa_data(qa_data)
  293. # 3. 保存完整数据集(可选)
  294. if save_full_data:
  295. output_path = input_json.replace(".json", "_full.json")
  296. save_full_dataset(full_data, output_path)
  297. # 4. 转换为 Dataset 格式
  298. print(f"\n步骤 3: 准备评估数据集")
  299. df = pd.DataFrame(full_data)
  300. dataset = Dataset.from_pandas(df)
  301. print(f" ✓ 数据集准备完成,共 {len(dataset)} 条数据")
  302. # 5. 执行评估
  303. print(f"\n步骤 4: 执行 RAGAS 评估")
  304. result = run_evaluation(dataset)
  305. # 6. 输出结果
  306. print("\n" + "="*60)
  307. print("评估结果汇总")
  308. print("="*60)
  309. for metric, value in result["summary"].items():
  310. print(f" {metric}: {value:.4f}")
  311. # 保存评估结果
  312. result_path = input_json.replace(".json", "_result.json")
  313. with open(result_path, 'w', encoding='utf-8') as f:
  314. json.dump(result, f, ensure_ascii=False, indent=2)
  315. print(f"\n评估结果已保存到: {result_path}")
  316. return result
  317. def main():
  318. """主函数入口"""
  319. # 配置参数
  320. input_json_file = "dataset_qa.json" # 输入文件(只包含 question 和 ground_truth)
  321. knowledge_ids = ["a2963496869283893248"] # 需要配置实际的知识库 ID
  322. # 检查输入文件
  323. if not os.path.exists(input_json_file):
  324. print(f"错误: 找不到输入文件 {input_json_file}")
  325. print("\n请创建一个 JSON 文件,格式如下:")
  326. print("""
  327. [
  328. {
  329. "question": "谁发明了电话?",
  330. "ground_truth": "电话的主要发明者是亚历山大·格拉汉姆·贝尔,他于1876年获得了电话的发明专利。"
  331. },
  332. {
  333. "question": "地球上最大的哺乳动物是什么?",
  334. "ground_truth": "地球上最大的哺乳动物是蓝鲸,成年蓝鲸体长可达30米,重量可达180吨。"
  335. }
  336. ]
  337. """)
  338. return
  339. # 运行异步主函数
  340. asyncio.run(main_async(input_json_file, knowledge_ids, save_full_data=True))
  341. if __name__ == "__main__":
  342. main()