| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407 |
- """
- RAGAS评估脚本 - 集成RAG检索功能 - FastAPI 版本
- 基于 ragas_eval_with_rag.py 创建的 API 服务
- 功能:
- - 上传 JSON 文件(只需 question 和 ground_truth)
- - 自动通过 RAG 获取 contexts 和 answer
- - 使用全部 5 个 RAGAS 指标进行评估
- - 返回详细的评估结果和统计
- """
- import os
- import sys
- import json
- import pandas as pd
- import asyncio
- import math
- from datasets import Dataset
- from typing import List, Dict, Any
- # FastAPI 相关
- from fastapi import FastAPI, UploadFile, File, HTTPException, Form
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.responses import FileResponse, JSONResponse
- from fastapi.staticfiles import StaticFiles
- import uvicorn
- # 添加项目根目录到 Python 路径
- project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- sys.path.insert(0, project_root)
- # RAGAS 评估相关
- from ragas import evaluate
- from ragas.metrics import (
- faithfulness,
- answer_correctness,
- answer_relevancy,
- context_precision,
- context_recall,
- NoiseSensitivity
- )
- from langchain_openai import ChatOpenAI, OpenAIEmbeddings
- # RAG 相关导入
- from rag.chat_message import ChatRetrieverRag
- from rag.llm import VllmApi
# VLLM and embedding service endpoints
VLLM_LLM_BASE = "http://xia0miduo.gicp.net:8102/v1"
VLLM_LLM_KEY = "vllm-dummy-key"
VLLM_EMBEDDING_BASE = "http://10.168.100.17:8787/v1/"
# Judge LLM used by the RAGAS evaluation itself (not the RAG answer model)
vllm_generator = ChatOpenAI(
    model="Qwen3-Coder-30B-loft",
    base_url=VLLM_LLM_BASE,
    api_key=VLLM_LLM_KEY,
    temperature=0.1,
    max_tokens=51200,
    timeout=3000,
    request_timeout=3000,
    max_retries=5,
)
# Embeddings used by RAGAS. NOTE(review): model name is empty — presumably
# the embedding server serves a single default model; confirm.
vllm_embeddings = OpenAIEmbeddings(
    model="",
    base_url=VLLM_EMBEDDING_BASE,
    api_key=VLLM_LLM_KEY,
)
# Metrics to run during evaluation
metrics_to_run = [
    faithfulness,
    answer_correctness,
    answer_relevancy,
    context_precision,
    context_recall,
    NoiseSensitivity(llm=vllm_generator)
]
class RAGEvaluator:
    """Evaluator that obtains answers and contexts through the RAG pipeline."""

    def __init__(
        self,
        knowledge_ids: List[str],
        embedding_id: str = "e5",
        temperature: float = 0.6,
        top_p: float = 0.7,
        max_tokens: int = 4096,
        model: str = "Qwen3-Coder-30B-loft",
        enable_think: bool = False,
        slice_count: int = 5
    ):
        """Initialize the evaluator.

        Args:
            knowledge_ids: knowledge-base ID list
            embedding_id: embedding model ID
            temperature: LLM temperature parameter
            top_p: LLM top_p parameter
            max_tokens: maximum number of generated tokens
            model: LLM model name
            enable_think: whether to enable "thinking" mode
            slice_count: number of retrieval slices
        """
        self.knowledge_ids = knowledge_ids
        self.embedding_id = embedding_id
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        # NOTE(review): `model` is stored here but never placed into
        # `rag_config` below — confirm ChatRetrieverRag selects the model
        # some other way, otherwise this parameter has no effect.
        self.model = model
        self.enable_think = enable_think
        self.slice_count = slice_count

        # Request template sent to the RAG pipeline; `query` is added per call.
        self.rag_config = {
            "knowledgeIds": knowledge_ids,
            "embeddingId": embedding_id,
            "sliceCount": slice_count,
            "knowledgeInfo": json.dumps({
                "recall_method": "mixed",
                "rerank_status": True,
                "rerank_model_name": "bce_rerank_model"
            }),
            "temperature": temperature,
            "topP": top_p,
            "maxToken": max_tokens,
            "enable_think": enable_think,
            "prompt": """你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{知识}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:"进入登录页面(图1.1)"需要从文字中删除图序,回复效果为:"进入登录页面";"如图所示1.1",回复效果为:"如图所示";\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:"我想","我认为","think"等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{用户}"""
        }

    async def get_rag_response(self, question: str) -> Dict[str, Any]:
        """Retrieve contexts for *question* and generate an answer via RAG.

        Returns a dict with keys `answer`, `contexts`, `retriever_count`.
        """
        chat_json = self.rag_config.copy()
        chat_json["query"] = question

        rag_retriever = ChatRetrieverRag(chat_json=chat_json)
        retriever_result_list, search_doc_id_to_knowledge_id_dict = rag_retriever.retriever_result(chat_json)
        chunk_content, knowledge_info_dict = rag_retriever.parse_retriever_list(
            retriever_result_list,
            search_doc_id_to_knowledge_id_dict
        )

        # Assumes each retriever result item carries a "content" key — TODO confirm.
        contexts = [item["content"] for item in retriever_result_list]

        answer = ""
        # NOTE(review): this keeps only the payload of the last "add" event,
        # i.e. it assumes "add" events carry the cumulative answer so far.
        # If the events are incremental deltas, chunks should be concatenated
        # instead — confirm against generate_rag_response.
        async for event in rag_retriever.generate_rag_response(chat_json, chunk_content):
            if event.get("event") == "add":
                answer = event.get("data", "")
            elif event.get("event") == "finish":
                break

        return {
            "answer": answer,
            "contexts": contexts,
            "retriever_count": len(retriever_result_list)
        }

    async def process_qa_data(self, qa_data: List[Dict[str, str]]) -> List[Dict[str, Any]]:
        """Enrich question/ground_truth pairs with RAG contexts and answers.

        Failed items are kept with empty `contexts`/`answer` so the output
        list stays aligned with the input.
        """
        full_data: List[Dict[str, Any]] = []

        for idx, item in enumerate(qa_data):
            question = item.get("question", "")
            ground_truth = item.get("ground_truth", "")

            print(f"处理第 {idx + 1}/{len(qa_data)} 个问题: {question[:50]}...")

            try:
                rag_result = await self.get_rag_response(question)

                full_item = {
                    "question": question,
                    "contexts": rag_result["contexts"],
                    "answer": rag_result["answer"],
                    "ground_truth": ground_truth,
                }

                full_data.append(full_item)
                print(f" ✓ 检索到 {rag_result['retriever_count']} 个上下文")

            except Exception as e:
                # Best-effort: record the failure and continue with the rest.
                print(f" ✗ 处理失败: {e}")
                full_data.append({
                    "question": question,
                    "contexts": [],
                    "answer": "",
                    "ground_truth": ground_truth,
                })

        return full_data
- def _clean_float_values(obj):
- """清理非法浮点数值(NaN, Infinity)为 None"""
- if isinstance(obj, float):
- if math.isnan(obj) or math.isinf(obj):
- return None
- return obj
- elif isinstance(obj, dict):
- return {k: _clean_float_values(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [_clean_float_values(item) for item in obj]
- return obj
def run_evaluation(dataset: Dataset) -> Dict[str, Any]:
    """Run the RAGAS evaluation over *dataset* and return JSON-safe results.

    Args:
        dataset: HF Dataset with columns `question`, `contexts`, `answer`
            and (optionally) `ground_truth`.

    Returns:
        dict with keys:
            rows: per-sample detail rows (NaN/Inf cleaned to None)
            summary: mean score per metric (None when not computable)
            metrics: names of the metrics actually run
            count: number of evaluated rows
    """
    print("\n" + "="*60)
    print("开始 RAGAS 评估...")
    print("="*60)

    final_metrics = list(metrics_to_run)
    if 'ground_truth' not in dataset.column_names:
        # Fix: context_recall is not the only reference-based metric —
        # answer_correctness and NoiseSensitivity also compare against
        # ground_truth. Since evaluate() runs with raise_exceptions=False
        # they would silently come back as NaN while still burning judge-LLM
        # calls, so drop every ground-truth-dependent metric here.
        print("警告: 缺少 'ground_truth',将跳过依赖参考答案的指标")
        final_metrics = [
            m for m in final_metrics
            if m is not context_recall
            and m is not answer_correctness
            and not isinstance(m, NoiseSensitivity)
        ]

    print(f"使用指标: {[m.name for m in final_metrics]}")
    print(f"数据集大小: {len(dataset)} 条\n")

    try:
        result = evaluate(
            dataset=dataset,
            metrics=final_metrics,
            llm=vllm_generator,
            embeddings=vllm_embeddings,
            raise_exceptions=False,
            show_progress=True,
        )
    except Exception as e:
        print(f"\n评估过程中出现错误: {e}")
        raise

    result_df = result.to_pandas()

    # Replace Inf/NaN in the DataFrame with None so it survives JSON encoding
    result_df = result_df.replace([float('inf'), float('-inf')], None)
    result_df = result_df.where(pd.notna(result_df), None)

    # Average only the actual metric columns (skip question/answer/... columns)
    metric_names = [m.name for m in final_metrics]
    summary = {}
    for metric_name in metric_names:
        if metric_name in result_df.columns:
            try:
                mean_val = result_df[metric_name].mean(skipna=True)
                if pd.notna(mean_val) and math.isfinite(mean_val):
                    summary[metric_name] = float(mean_val)
                else:
                    summary[metric_name] = None
            except Exception:
                # Non-numeric column or all-None values: no usable mean
                summary[metric_name] = None

    print("\n" + "="*60)
    print("评估完成!")
    print("="*60)

    # Convert to plain dicts and strip any remaining non-finite floats
    rows = result_df.to_dict(orient="records")

    return {
        "rows": _clean_float_values(rows),        # per-sample detail rows
        "summary": _clean_float_values(summary),  # per-metric mean scores
        "metrics": [m.name for m in final_metrics],
        "count": len(result_df),                  # number of evaluated rows (not metric count)
    }
# FastAPI application
app = FastAPI(
    title="RAGAS RAG 评估 API",
    description="集成 RAG 检索功能的 RAGAS 评估服务",
    version="1.0.0"
)
# Wildcard CORS — presumably this service runs on a trusted internal
# network; tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the front-end static files if the directory exists
frontend_dir = os.path.join(os.path.dirname(__file__), "frontend")
if os.path.isdir(frontend_dir):
    app.mount("/static", StaticFiles(directory=frontend_dir), name="static")
@app.get("/")
def index():
    """Serve the front-end page, falling back to a JSON banner.

    Tries rag_eval.html first, then index.html; when neither file exists
    a minimal JSON payload describing the API is returned instead.
    """
    for page_name in ("rag_eval.html", "index.html"):
        candidate = os.path.join(frontend_dir, page_name)
        if os.path.isfile(candidate):
            return FileResponse(candidate)

    # No front-end assets available
    return JSONResponse({
        "message": "RAGAS RAG 评估 API",
        "docs": "/docs",
        "version": "1.0.0"
    })
@app.post("/api/evaluate")
async def api_evaluate(
    file: UploadFile = File(...),
    knowledge_ids: str = Form(...),
    embedding_id: str = Form("e5"),
    temperature: float = Form(0.6),
    top_p: float = Form(0.7),
    max_tokens: int = Form(4096),
    model: str = Form("Qwen3-Coder-30B-loft"),
    enable_think: bool = Form(False),
    slice_count: int = Form(5)
):
    """
    Evaluation endpoint.

    Args:
        file: JSON file of objects containing `question` and `ground_truth`
        knowledge_ids: knowledge-base IDs, comma separated
        embedding_id: embedding model ID (default "e5")
        temperature: LLM temperature (default 0.6)
        top_p: LLM top_p (default 0.7)
        max_tokens: maximum generated tokens (default 4096)
        model: LLM model name (default "Qwen3-Coder-30B-loft")
        enable_think: enable "thinking" mode (default False)
        slice_count: number of retrieval slices (default 5)
    """
    # Fix: UploadFile.filename can be None for some clients; treat that as
    # an invalid upload instead of crashing on `.lower()`.
    if not file.filename or not file.filename.lower().endswith(".json"):
        raise HTTPException(status_code=400, detail="仅支持 .json 文件")

    try:
        # Parse the knowledge-base IDs
        knowledge_id_list = [kid.strip() for kid in knowledge_ids.split(",") if kid.strip()]
        if not knowledge_id_list:
            raise ValueError("knowledge_ids 不能为空")

        # Read and decode the uploaded file
        content = await file.read()
        qa_data = json.loads(content.decode("utf-8"))

        if not isinstance(qa_data, list):
            raise ValueError("JSON 必须是数组格式")

        if not qa_data:
            raise ValueError("数据不能为空")

        # Build the evaluator with the requested RAG configuration
        print(f"\n配置: 模型={model}, 温度={temperature}, top_p={top_p}, 切片数={slice_count}")
        evaluator = RAGEvaluator(
            knowledge_ids=knowledge_id_list,
            embedding_id=embedding_id,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model=model,
            enable_think=enable_think,
            slice_count=slice_count
        )

        # Retrieve contexts and generate answers through RAG
        print(f"\n开始处理 {len(qa_data)} 条数据...")
        full_data = await evaluator.process_qa_data(qa_data)

        # Assemble the RAGAS dataset
        dataset = Dataset.from_dict({
            "question": [item["question"] for item in full_data],
            "contexts": [item["contexts"] for item in full_data],
            "answer": [item["answer"] for item in full_data],
            "ground_truth": [item["ground_truth"] for item in full_data],
        })

        # Run the evaluation and return the JSON-safe result
        result = run_evaluation(dataset)

        return JSONResponse(result)

    except json.JSONDecodeError as e:
        # Must precede ValueError: JSONDecodeError is a ValueError subclass.
        raise HTTPException(status_code=400, detail=f"JSON 解析失败: {e}")
    except ValueError as e:
        # Fix: validation failures are client errors -> 400. Previously they
        # fell through to the generic handler and surfaced as HTTP 500.
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        # Fix: never re-wrap deliberately raised HTTP errors as 500.
        raise
    except Exception as e:
        import traceback
        error_detail = f"评估失败: {e}\n{traceback.format_exc()}"
        print(error_detail)
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # The listen port can be overridden through the PORT environment variable.
    port = int(os.environ.get("PORT", "8001"))

    print("启动 RAGAS RAG 评估 API 服务...")
    print(f"访问地址: http://0.0.0.0:{port}")
    print(f"API 文档: http://0.0.0.0:{port}/docs")

    # Import string keeps uvicorn's reload machinery usable if ever enabled.
    uvicorn.run("ragas_eval_with_rag_api:app", host="0.0.0.0", port=port, reload=False)
|