| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
import os
import json
import pandas as pd
from datasets import Dataset
from typing import List, Dict, Any
# FastAPI imports
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import uvicorn

# Force server mode: the __main__ guard below reads RUN_SERVER, so setting it
# here means running this file always starts the FastAPI app (the CLI branch
# at the bottom is effectively unreachable while this line is present).
os.environ["RUN_SERVER"] = "True"

from ragas import evaluate
from ragas.metrics import (
    faithfulness,
    answer_correctness,
    answer_relevancy,
    context_precision,
    context_recall,
)
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Endpoints of the locally hosted vLLM services used as judge LLM / embedder.
VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
VLLM_LLM_KEY = "vllm-dummy-key"  # vLLM does not validate keys; placeholder only
VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"

# Judge LLM used by RAGAS to score each sample (temperature 0 for determinism).
vllm_generator = ChatOpenAI(
    model="Qwen3-Coder-30B-loft",
    base_url=VLLM_LLM_BASE,
    api_key=VLLM_LLM_KEY,
    temperature=0,
    max_tokens=512,
)

# Embedding client for the similarity-based metrics.
# NOTE(review): model name is empty — presumably the embedding server serves a
# single default model; confirm against the service at VLLM_EMBEDDING_BASE.
vllm_embeddings = OpenAIEmbeddings(
    model="",
    base_url=VLLM_EMBEDDING_BASE,
    api_key=VLLM_LLM_KEY,
)
# --- 2. Load and prepare the dataset ---
def load_data_from_json(json_path):
    """
    Load evaluation records from a JSON file and convert them to a
    ``datasets.Dataset``.

    Expected JSON structure::

        [
            {
                "question": "...",
                "answer": "...",
                "contexts": ["...", "..."],
                "ground_truth": "..."    # optional
            },
            ...
        ]

    :param json_path: path to the JSON file to load.
    :return: the ``Dataset`` on success, or ``None`` when loading or
        validation fails (the error is printed; callers such as
        ``run_evaluation`` check for ``None``).
    """
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data_list = json.load(f)

        df = pd.DataFrame(data_list)
        required_cols = {'question', 'answer', 'contexts'}
        if not required_cols.issubset(df.columns):
            # was an f-string with no placeholders; a plain literal is identical
            raise ValueError("JSON 文件必须包含 'question', 'answer', 和 'contexts' 键")
        # ground_truth is optional, but context_recall cannot run without it.
        if 'ground_truth' not in df.columns:
            print("警告: 未找到 'ground_truth' 列。'context_recall' 指标将无法计算。")

        dataset = Dataset.from_pandas(df)
        return dataset
    except Exception as e:
        # Best-effort CLI loader: report and signal failure with None.
        print(f"加载数据时出错: {e}")
        return None
def load_data_from_list(data_list: List[Dict[str, Any]]):
    """
    Build a ``datasets.Dataset`` from an in-memory list of record dicts.

    Used by the FastAPI upload endpoint, so failures are raised as
    ``ValueError`` (mapped to HTTP 400 by the caller) instead of being
    printed like the CLI loader.

    :param data_list: list of dicts with 'question', 'answer', 'contexts'
        keys and an optional 'ground_truth' key.
    :raises ValueError: when required keys are missing or conversion fails.
    """
    try:
        df = pd.DataFrame(data_list)
        required_cols = {'question', 'answer', 'contexts'}
        if not required_cols.issubset(df.columns):
            raise ValueError("JSON 必须包含 'question', 'answer', 'contexts' 键")
        if 'ground_truth' not in df.columns:
            print("警告: 未找到 'ground_truth' 列。'context_recall' 指标将无法计算。")
        dataset = Dataset.from_pandas(df)
        return dataset
    except Exception as e:
        # Chain the original exception so bad uploads keep their traceback.
        raise ValueError(f"加载上传数据时出错: {e}") from e
# --- 3. Metric definitions ---
# Default metric set. context_recall is removed at runtime whenever the
# input data has no 'ground_truth' column (it cannot be computed without it).
metrics_to_run = [
    faithfulness,
    answer_correctness,
    answer_relevancy,
    context_precision,
    context_recall,
]
# --- 4. Run the evaluation (CLI path) ---
def run_evaluation(dataset_path):
    """
    Load a JSON dataset from ``dataset_path``, run RAGAS over it, and print
    the per-row results as JSON.

    :param dataset_path: path to the evaluation JSON file.
    :return: the result ``DataFrame`` on success, ``None`` when loading fails
        or no metrics are runnable. (The function previously always returned
        None implicitly, so returning the DataFrame is backward-compatible
        for callers that ignore the return value.)
    """
    print(f"正在从 {dataset_path} 加载数据...")
    dataset = load_data_from_json(dataset_path)
    if dataset is None:
        print("数据加载失败,退出评估。")
        return None
    print(f"成功加载 {len(dataset)} 条评估数据。")

    final_metrics = list(metrics_to_run)
    # context_recall requires a ground_truth column; drop it when absent.
    if 'ground_truth' not in dataset.column_names:
        print("由于缺少 'ground_truth',将跳过 'context_recall' 指标。")
        final_metrics.remove(context_recall)
    if not final_metrics:
        print("没有可执行的指标。退出。")
        return None
    print("开始 RAGAS 评估... (这可能需要一些时间)")

    # Core call: RAGAS drives the judge LLM + embeddings over every row.
    result = evaluate(
        dataset=dataset,
        metrics=final_metrics,
        llm=vllm_generator,
        embeddings=vllm_embeddings
    )
    print("评估完成!")
    result_df = result.to_pandas()
    print(json.dumps(result_df.to_dict(orient="records"), indent=4, ensure_ascii=False))
    return result_df
def summarize_metrics(result_df: pd.DataFrame) -> Dict[str, float]:
    """
    Aggregate the mean score per metric column for frontend display.

    Columns that are absent, or whose mean cannot be converted to float,
    are silently skipped (best-effort summary).
    """
    metric_names = (
        'faithfulness',
        'answer_correctness',
        'answer_relevancy',
        'context_precision',
        'context_recall',
    )
    summary: Dict[str, float] = {}
    for name in metric_names:
        if name not in result_df.columns:
            continue
        try:
            summary[name] = float(result_df[name].mean())
        except Exception:
            # Non-numeric column (e.g. all-NaN objects): skip quietly.
            continue
    return summary
def evaluate_dataset(dataset: Dataset, final_metrics: List = None) -> Dict[str, Any]:
    """
    Run RAGAS on ``dataset`` and return row-level results plus a summary.

    :param dataset: prepared evaluation dataset.
    :param final_metrics: metrics to run; defaults to ``metrics_to_run``.
        The list is copied internally, so the caller's list is never mutated
        (the original implementation removed items from the caller's list).
    :return: dict with keys ``rows`` (per-row records), ``summary``
        (mean per metric), ``metrics`` (names actually run), ``count``.
    """
    # Copy so that removing context_recall cannot mutate the caller's list.
    final_metrics = list(metrics_to_run if final_metrics is None else final_metrics)
    # context_recall needs ground_truth; drop it when the column is absent.
    if 'ground_truth' not in dataset.column_names and context_recall in final_metrics:
        final_metrics.remove(context_recall)
    result = evaluate(
        dataset=dataset,
        metrics=final_metrics,
        llm=vllm_generator,
        embeddings=vllm_embeddings,
    )
    result_df = result.to_pandas()
    print(json.dumps(result_df.to_dict(orient="records"), indent=4, ensure_ascii=False))
    return {
        "rows": result_df.to_dict(orient="records"),
        "summary": summarize_metrics(result_df),
        "metrics": [m.name for m in final_metrics],
        "count": len(result_df),
    }
# --- FastAPI application ---
app = FastAPI()

# Allow the local frontend (any origin) to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serve the bundled frontend if a ./frontend directory exists next to this file.
frontend_dir = os.path.join(os.path.dirname(__file__), "frontend")
if os.path.isdir(frontend_dir):
    app.mount("/static", StaticFiles(directory=frontend_dir), name="static")
@app.get("/")
def index():
    """Serve the frontend entry page when present, otherwise a JSON hint."""
    index_path = os.path.join(frontend_dir, "index.html")
    if not os.path.isfile(index_path):
        # No frontend built: point callers at the interactive API docs.
        return JSONResponse({"message": "前端未创建,请访问 /api/docs 使用接口。"})
    return FileResponse(index_path)
@app.post("/api/evaluate")
async def api_evaluate(file: UploadFile = File(...)):
    """
    Upload a JSON file (an array where each item contains question/answer/
    contexts and optionally ground_truth) and run the RAGAS evaluation.

    :return: JSON payload with rows, summary, metrics, count, and the list
        of metrics skipped because ground_truth was missing.
    :raises HTTPException: 400 for non-JSON uploads, malformed JSON, or any
        evaluation failure.
    """
    # filename can be None depending on the client; treat that as invalid too
    # (previously this would raise AttributeError -> HTTP 500).
    if not file.filename or not file.filename.lower().endswith(".json"):
        raise HTTPException(status_code=400, detail="仅支持上传 .json 文件")
    try:
        content = await file.read()
        data_list = json.loads(content.decode("utf-8"))
        if not isinstance(data_list, list):
            raise ValueError("JSON 顶层必须为数组")
        dataset = load_data_from_list(data_list)

        # Drop context_recall (and tell the frontend) when ground_truth is missing.
        final_metrics = list(metrics_to_run)
        if 'ground_truth' not in dataset.column_names and context_recall in final_metrics:
            final_metrics.remove(context_recall)
            skipped = ["context_recall"]
        else:
            skipped = []

        result_payload = evaluate_dataset(dataset, final_metrics=final_metrics)
        result_payload["skipped_metrics"] = skipped
        return JSONResponse(result_payload)
    except Exception as e:
        # Surface any failure as a client error with the root cause attached.
        raise HTTPException(status_code=400, detail=f"评估失败: {e}") from e
# --- Entry point ---
if __name__ == "__main__":
    # CLI mode still supports evaluating an example file directly.
    input_json_file = "dataset.json"
    # Start the server when the RUN_SERVER environment variable is truthy.
    # NOTE(review): RUN_SERVER is hard-set to "True" at import time near the
    # top of this file, so the CLI branch below is currently unreachable.
    run_server = os.environ.get("RUN_SERVER", "").lower() in {"1", "true", "yes"}
    if run_server:
        # RELOAD toggles uvicorn hot-reload; off by default so a reload does
        # not interrupt the slow imports of the large ML libraries.
        reload_enabled = os.environ.get("RELOAD", "").lower() in {"1", "true", "yes"}
        uvicorn.run("ragas_eval:app", host="0.0.0.0", port=8000, reload=reload_enabled)
    else:
        if not os.path.exists(input_json_file):
            print(f"错误: 找不到 {input_json_file}")
            print("请创建一个示例 JSON 文件,结构如下:")
            print(
                """
[
    {
        "question": "什么是 RAGAS?",
        "answer": "RAGAS 是一个评估 RAG 管道的框架。",
        "contexts": [
            "RAGAS (Retrieval-Augmented Generation Assessment) 是一个用于评估 RAG 管道的框架。"
        ],
        "ground_truth": "RAGAS 是一个专为评估 RAG 管道设计的框架,它关注检索和生成的质量。"
    }
]
                """
            )
        else:
            run_evaluation(input_json_file)
|