# dome.py — RAGAS evaluation script for a vLLM-served RAG pipeline.
  1. import os
  2. import json
  3. import pandas as pd
  4. from datasets import Dataset
  5. from ragas import evaluate
  6. from ragas.metrics import (
  7. faithfulness,
  8. answer_correctness,
  9. answer_relevancy,
  10. context_precision,
  11. context_recall,
  12. NoiseSensitivity
  13. )
  14. from langchain_openai import ChatOpenAI, OpenAIEmbeddings
  15. VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
  16. VLLM_LLM_KEY = "vllm-dummy-key"
  17. VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"
  18. vllm_generator = ChatOpenAI(
  19. model="Qwen3-Coder-30B-loft",
  20. base_url=VLLM_LLM_BASE,
  21. api_key=VLLM_LLM_KEY,
  22. temperature=0,
  23. max_tokens=512,
  24. )
  25. vllm_embeddings = OpenAIEmbeddings(
  26. model="",
  27. base_url=VLLM_EMBEDDING_BASE,
  28. api_key=VLLM_LLM_KEY,
  29. )
  30. # --- 2. 加载和准备数据集 ---
  31. def load_data_from_json(json_path):
  32. """
  33. 从 JSON 文件加载数据并转换为 datasets.Dataset 格式。
  34. 假设的 JSON 结构:
  35. [
  36. {
  37. "question": "...",
  38. "answer": "...",
  39. "contexts": ["...", "..."],
  40. "ground_truth": "..."
  41. },
  42. ...
  43. ]
  44. """
  45. try:
  46. with open(json_path, 'r', encoding='utf-8') as f:
  47. data_list = json.load(f)
  48. df = pd.DataFrame(data_list)
  49. required_cols = {'question', 'answer', 'contexts'}
  50. if not required_cols.issubset(df.columns):
  51. raise ValueError(f"JSON 文件必须包含 'question', 'answer', 和 'contexts' 键")
  52. if 'ground_truth' not in df.columns:
  53. print("警告: 未找到 'ground_truth' 列。'context_recall' 指标将无法计算。")
  54. dataset = Dataset.from_pandas(df)
  55. return dataset
  56. except Exception as e:
  57. print(f"加载数据时出错: {e}")
  58. return None
  59. # --- 3. 定义指标 ---
  60. metrics_to_run = [
  61. faithfulness,
  62. answer_correctness,
  63. answer_relevancy,
  64. context_precision,
  65. context_recall,
  66. NoiseSensitivity(llm=vllm_generator),
  67. ]
  68. # --- 4. 执行评估 ---
  69. def run_evaluation(dataset_path):
  70. print(f"正在从 {dataset_path} 加载数据...")
  71. dataset = load_data_from_json(dataset_path)
  72. if dataset is None:
  73. print("数据加载失败,退出评估。")
  74. return
  75. print(f"成功加载 {len(dataset)} 条评估数据。")
  76. final_metrics = list(metrics_to_run)
  77. if 'ground_truth' not in dataset.column_names:
  78. print("由于缺少 'ground_truth',将跳过 'context_recall' 指标。")
  79. final_metrics.remove(context_recall)
  80. if not final_metrics:
  81. print("没有可执行的指标。退出。")
  82. return
  83. print("开始 RAGAS 评估... (这可能需要一些时间)")
  84. # 核心:调用 evaluate
  85. result = evaluate(
  86. dataset=dataset,
  87. metrics=final_metrics,
  88. llm=vllm_generator,
  89. embeddings=vllm_embeddings
  90. )
  91. print("评估完成!")
  92. result_df = result.to_pandas()
  93. print(json.dumps(result_df.to_dict(orient="records"), indent=4, ensure_ascii=False))
  94. # --- 主程序入口 ---
  95. if __name__ == "__main__":
  96. input_json_file = "dataset.json"
  97. if not os.path.exists(input_json_file):
  98. print(f"错误: 找不到 {input_json_file}")
  99. print("请创建一个示例 JSON 文件,结构如下:")
  100. print("""
  101. [
  102. {
  103. "question": "什么是 RAGAS?",
  104. "answer": "RAGAS 是一个评估 RAG 管道的框架。",
  105. "contexts": ["RAGAS (Retrieval-Augmented Generation Assessment) 是一个用于评估 RAG 管道的框架。"],
  106. "ground_truth": "RAGAS 是一个专为评估 RAG 管道设计的框架,它关注检索和生成的质量。"
  107. }
  108. ]
  109. """)
  110. else:
  111. run_evaluation(input_json_file)