| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 简单的RAG项目RAGAS评估脚本
- 包含常用评估指标:答案相关性、上下文精确度、上下文召回率、答案正确性
- 使用DeepSeek API进行LLM评估
- """
- import json
- import asyncio
- import aiohttp
- import pandas as pd
- from typing import List, Dict, Any
- import logging
- from datetime import datetime
- import time
- import re
- from pathlib import Path
- from openai import AsyncOpenAI
- # 配置日志
- logging.basicConfig(
- level=logging.INFO,
- format='%(asctime)s - %(levelname)s - %(message)s',
- handlers=[
- logging.FileHandler('rag_evaluation.log', encoding='utf-8'),
- logging.StreamHandler()
- ]
- )
- logger = logging.getLogger(__name__)
- # DeepSeek API配置
- DEEPSEEK_API_KEY = "sk-72f9d0e5bc894e1d828d73bdcc50ff0a" # 请替换为实际的API Key
- DEEPSEEK_BASE_URL = "https://api.deepseek.com"
- class SimpleRAGEvaluator:
- def __init__(self):
- """
- 初始化RAG评估器
- """
- self.api_url = "http://localhost:6000/rag/chat"
- self.timeout = 60
- self.max_retries = 3
- self.retry_delay = 2
-
- # 初始化DeepSeek客户端
- self.deepseek_client = AsyncOpenAI(
- api_key=DEEPSEEK_API_KEY,
- base_url=DEEPSEEK_BASE_URL
- )
-
- # 默认请求参数
- self.default_request = {
- "appId": "2924812721300312064",
- "desc": "高井信息员工手册,出差管理制度等",
- "isDeepThink": "N",
- "knowledgeIds": [
- "a2963496869283893248",
- "a2963501316240183296"
- ],
- "knowledgeInfo": '{"param_desc":"strict","show_recall_result":true,"recall_method":"embedding","rerank_status":true,"rerank_model_name":"rerank","slice_config_type":"customized","rerank_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}],"recall_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}]}',
- "maxToken": 8192,
- "model": "Qwen3-30B",
- "name": "高井信息公司管理制度",
- "params": {},
- "prompt": "你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{{知识}}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:\"进入登录页面(图1.1)\"需要从文字中删除图序,回复效果为:\"进入登录页面\";\"如图所示1.1\",回复效果为:\"如图所示\";\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:\"我想\",\"我认为\",\"think\"等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{{用户}}",
- "status": "3",
- "temperature": "0.01",
- "topP": "0.5",
- "typeId": 40,
- "updateBy": "1",
- "visible": "0",
- "embeddingId": "multilingual-e5-large-instruct",
- "enable_think": False
- }
-
- self.results = []
- self.stats = {
- 'total_questions': 0,
- 'successful_calls': 0,
- 'failed_calls': 0,
- 'avg_response_time': 0,
- 'response_times': []
- }
-
- def load_validation_data(self, file_path: str = "val_data.json") -> List[Dict[str, Any]]:
- """
- 加载验证数据
- """
- try:
- with open(file_path, 'r', encoding='utf-8') as f:
- data = json.load(f)
- logger.info(f"成功加载 {len(data)} 条验证数据")
- return data
- except Exception as e:
- logger.error(f"加载验证数据失败: {e}")
- raise
-
- async def call_rag_api(self, session: aiohttp.ClientSession, question: str) -> Dict[str, Any]:
- """
- 调用RAG API并处理流式响应
- """
- request_data = self.default_request.copy()
- request_data["query"] = question
-
- try:
- start_time = time.time()
-
- async with session.post(
- self.api_url,
- json=request_data,
- timeout=aiohttp.ClientTimeout(total=self.timeout),
- headers={'Accept': 'text/event-stream'}
- ) as response:
- if response.status != 200:
- raise Exception(f"API返回错误状态码: {response.status}")
-
- full_answer = ""
- chat_id = ""
- contexts_from_api = []
-
- # 处理流式响应
- async for line in response.content:
- line = line.decode('utf-8').strip()
- if not line or line.startswith(':'):
- continue
-
- if line.startswith('data: '):
- data_str = line[6:] # 移除 'data: ' 前缀
- if data_str == '[DONE]':
- break
-
- # 尝试解析JSON格式
- try:
- event_data = json.loads(data_str)
-
- # 确保event_data是字典类型才进行处理
- if isinstance(event_data, dict):
- if event_data.get('event') == 'add':
- # 如果是流式输出包含完整内容,直接替换而不是累积
- chunk_data = event_data.get('data', '')
- if isinstance(chunk_data, str):
- full_answer = chunk_data # 直接替换,因为流式输出包含完整内容
- chat_id = event_data.get('id', '') or chat_id
-
- # 提取上下文信息
- if 'contexts' in event_data:
- if isinstance(event_data['contexts'], list):
- contexts_from_api.extend(event_data['contexts'])
-
- elif event_data.get('event') == 'finish':
- break
-
- # 也检查直接的content和contexts字段(兼容不同的响应格式)
- if 'content' in event_data:
- full_answer = event_data['content']
- if 'contexts' in event_data:
- if isinstance(event_data['contexts'], list):
- contexts_from_api.extend(event_data['contexts'])
- else:
- logger.debug(f"跳过非字典类型的数据: {type(event_data)}")
-
- except json.JSONDecodeError:
- # 如果不是JSON格式,直接作为文本内容处理
- if data_str and data_str != '[DONE]':
- full_answer = data_str # 直接替换,因为流式输出包含完整内容
- continue
- except Exception as pe:
- logger.warning(f"处理流式数据失败: {pe}")
- continue
-
- end_time = time.time()
- response_time = end_time - start_time
- self.stats['response_times'].append(response_time)
-
- logger.info(f"API调用成功,答案长度: {len(full_answer)}, chat_id: {chat_id}, 上下文数量: {len(contexts_from_api)}")
-
- # 调试信息:显示提取到的上下文
- if contexts_from_api:
- logger.info(f"成功提取到 {len(contexts_from_api)} 个上下文片段")
- for i, ctx in enumerate(contexts_from_api[:2], 1): # 只显示前2个
- logger.info(f"上下文[{i}]: {ctx[:100]}{'...' if len(ctx) > 100 else ''}")
- else:
- logger.warning("未从API响应中提取到上下文信息,将使用验证数据中的预设上下文")
-
- return {
- 'answer': full_answer,
- 'chat_id': chat_id,
- 'contexts': contexts_from_api,
- 'response_time': response_time,
- 'success': True
- }
-
- except Exception as e:
- logger.error(f"API调用失败: {e}")
- return {
- 'answer': '',
- 'chat_id': '',
- 'response_time': 0,
- 'success': False,
- 'error': str(e)
- }
-
- async def calculate_answer_relevancy(self, question: str, answer: str) -> float:
- """
- 使用DeepSeek API计算答案相关性
- """
- if not answer.strip():
- return 0.0
-
- prompt = f"""
- 请评估以下答案对问题的相关性,给出0-1之间的分数(保留3位小数)。
- 问题:{question}
- 答案:{answer}
- 评估标准:
- - 1.0:答案完全回答了问题,高度相关
- - 0.8:答案大部分回答了问题,相关性较高
- - 0.6:答案部分回答了问题,有一定相关性
- - 0.4:答案与问题有关但不够直接
- - 0.2:答案与问题关系较弱
- - 0.0:答案与问题无关
- 请只返回数字分数,不要其他解释。
- """
-
- try:
- response = await self.deepseek_client.chat.completions.create(
- model="deepseek-chat",
- messages=[{"role": "user", "content": prompt}],
- temperature=0.1,
- max_tokens=10
- )
-
- score_text = response.choices[0].message.content.strip()
- score = float(re.findall(r'\d+\.\d+|\d+', score_text)[0])
- return min(max(score, 0.0), 1.0)
-
- except Exception as e:
- logger.warning(f"DeepSeek API调用失败,使用简化计算: {e}")
- # 回退到简化计算
- question_words = set(re.findall(r'\b\w+\b', question.lower()))
- answer_words = set(re.findall(r'\b\w+\b', answer.lower()))
- if len(question_words) == 0:
- return 0.0
- overlap = len(question_words.intersection(answer_words))
- return min(overlap / len(question_words), 1.0)
-
- async def calculate_context_precision(self, contexts: List[str], answer: str) -> float:
- """
- 使用DeepSeek API计算上下文精确度
- """
- if not contexts or not answer.strip():
- return 0.0
-
- contexts_text = "\n\n".join([f"上下文{i+1}: {ctx}" for i, ctx in enumerate(contexts)])
-
- prompt = f"""
- 请评估检索到的上下文对生成答案的精确度,给出0-1之间的分数(保留3位小数)。
- 检索到的上下文:
- {contexts_text}
- 生成的答案:
- {answer}
- 评估标准:
- - 1.0:所有上下文都与答案高度相关,没有无关信息
- - 0.8:大部分上下文与答案相关,少量无关信息
- - 0.6:一半以上上下文与答案相关
- - 0.4:部分上下文与答案相关
- - 0.2:少量上下文与答案相关
- - 0.0:上下文与答案基本无关
- 请只返回数字分数,不要其他解释。
- """
-
- try:
- response = await self.deepseek_client.chat.completions.create(
- model="deepseek-chat",
- messages=[{"role": "user", "content": prompt}],
- temperature=0.1,
- max_tokens=10
- )
-
- score_text = response.choices[0].message.content.strip()
- score = float(re.findall(r'\d+\.\d+|\d+', score_text)[0])
- return min(max(score, 0.0), 1.0)
-
- except Exception as e:
- logger.warning(f"DeepSeek API调用失败,使用简化计算: {e}")
- # 回退到简化计算
- answer_words = set(re.findall(r'\b\w+\b', answer.lower()))
- relevant_contexts = 0
- for context in contexts:
- context_words = set(re.findall(r'\b\w+\b', context.lower()))
- if len(context_words.intersection(answer_words)) > 0:
- relevant_contexts += 1
- return relevant_contexts / len(contexts) if contexts else 0.0
-
- async def calculate_context_recall(self, contexts: List[str], ground_truth: str) -> float:
- """
- 使用DeepSeek API计算上下文召回率
- """
- if not contexts or not ground_truth.strip():
- return 0.0
-
- contexts_text = "\n\n".join([f"上下文{i+1}: {ctx}" for i, ctx in enumerate(contexts)])
-
- prompt = f"""
- 请评估检索到的上下文对标准答案的召回率,给出0-1之间的分数(保留3位小数)。
- 检索到的上下文:
- {contexts_text}
- 标准答案:
- {ground_truth}
- 评估标准:
- - 1.0:上下文完全包含了标准答案的所有关键信息
- - 0.8:上下文包含了标准答案的大部分关键信息
- - 0.6:上下文包含了标准答案的一半以上关键信息
- - 0.4:上下文包含了标准答案的部分关键信息
- - 0.2:上下文包含了标准答案的少量关键信息
- - 0.0:上下文基本不包含标准答案的关键信息
- 请只返回数字分数,不要其他解释。
- """
-
- try:
- response = await self.deepseek_client.chat.completions.create(
- model="deepseek-chat",
- messages=[{"role": "user", "content": prompt}],
- temperature=0.1,
- max_tokens=10
- )
-
- score_text = response.choices[0].message.content.strip()
- score = float(re.findall(r'\d+\.\d+|\d+', score_text)[0])
- return min(max(score, 0.0), 1.0)
-
- except Exception as e:
- logger.warning(f"DeepSeek API调用失败,使用简化计算: {e}")
- # 回退到简化计算
- ground_truth_words = set(re.findall(r'\b\w+\b', ground_truth.lower()))
- all_context_words = set()
- for context in contexts:
- context_words = set(re.findall(r'\b\w+\b', context.lower()))
- all_context_words.update(context_words)
- if len(ground_truth_words) == 0:
- return 0.0
- overlap = len(ground_truth_words.intersection(all_context_words))
- return overlap / len(ground_truth_words)
-
- async def calculate_answer_correctness(self, answer: str, ground_truth: str) -> float:
- """
- 使用DeepSeek API计算答案正确性
- """
- if not answer.strip() or not ground_truth.strip():
- return 0.0
-
- prompt = f"""
- 请评估生成答案与标准答案的正确性,给出0-1之间的分数(保留3位小数)。
- 生成的答案:
- {answer}
- 标准答案:
- {ground_truth}
- 评估标准:
- - 1.0:答案在事实和语义上与标准答案完全一致
- - 0.8:答案在主要事实上正确,与标准答案高度一致
- - 0.6:答案在核心事实上正确,与标准答案较为一致
- - 0.4:答案部分正确,与标准答案有一定一致性
- - 0.2:答案少量正确,与标准答案一致性较低
- - 0.0:答案错误或与标准答案完全不一致
- 请只返回数字分数,不要其他解释。
- """
-
- try:
- response = await self.deepseek_client.chat.completions.create(
- model="deepseek-chat",
- messages=[{"role": "user", "content": prompt}],
- temperature=0.1,
- max_tokens=10
- )
-
- score_text = response.choices[0].message.content.strip()
- score = float(re.findall(r'\d+\.\d+|\d+', score_text)[0])
- return min(max(score, 0.0), 1.0)
-
- except Exception as e:
- logger.warning(f"DeepSeek API调用失败,使用简化计算: {e}")
- # 回退到简化计算
- answer_words = set(re.findall(r'\b\w+\b', answer.lower()))
- ground_truth_words = set(re.findall(r'\b\w+\b', ground_truth.lower()))
- if len(answer_words) == 0 and len(ground_truth_words) == 0:
- return 1.0
- if len(answer_words) == 0 or len(ground_truth_words) == 0:
- return 0.0
- intersection = len(answer_words.intersection(ground_truth_words))
- union = len(answer_words.union(ground_truth_words))
- return intersection / union if union > 0 else 0.0
-
- async def evaluate_single_item(self, session: aiohttp.ClientSession, item: Dict[str, Any]) -> Dict[str, Any]:
- """
- 评估单个问题
- """
- question = item['question']
- expected_answer = item['answer']
- contexts = item['contexts']
- ground_truth = item['ground_truth']
-
- logger.info(f"评估问题: {question[:50]}...")
-
- # 调用RAG API
- api_result = await self.call_rag_api(session, question)
-
- if not api_result['success']:
- self.stats['failed_calls'] += 1
- return {
- 'question': question,
- 'expected_answer': expected_answer,
- 'actual_answer': '',
- 'contexts': contexts,
- 'ground_truth': ground_truth,
- 'answer_relevancy': 0.0,
- 'context_precision': 0.0,
- 'context_recall': 0.0,
- 'answer_correctness': 0.0,
- 'response_time': 0,
- 'success': False,
- 'error': api_result.get('error', 'Unknown error')
- }
-
- actual_answer = api_result['answer']
- self.stats['successful_calls'] += 1
-
- # 使用API返回的上下文(如果有的话),否则使用验证数据中的上下文
- api_contexts = api_result.get('contexts', [])
- evaluation_contexts = api_contexts if api_contexts else contexts
-
- context_source = "API返回" if api_contexts else "验证数据预设"
- logger.info(f"使用上下文数量: {len(evaluation_contexts)} (来源: {context_source}), 答案长度: {len(actual_answer)}")
-
- # 显示实际使用的上下文(调试信息)
- if evaluation_contexts:
- logger.debug(f"评估使用的上下文片段:")
- for i, ctx in enumerate(evaluation_contexts[:2], 1):
- logger.debug(f" [{i}] {ctx[:80]}{'...' if len(ctx) > 80 else ''}")
-
- # 计算各项指标
- answer_relevancy = await self.calculate_answer_relevancy(question, actual_answer)
- context_precision = await self.calculate_context_precision(evaluation_contexts, actual_answer)
- context_recall = await self.calculate_context_recall(evaluation_contexts, ground_truth)
- answer_correctness = await self.calculate_answer_correctness(actual_answer, ground_truth)
-
- return {
- 'question': question,
- 'expected_answer': expected_answer,
- 'actual_answer': actual_answer,
- 'contexts': contexts,
- 'ground_truth': ground_truth,
- 'answer_relevancy': answer_relevancy,
- 'context_precision': context_precision,
- 'context_recall': context_recall,
- 'answer_correctness': answer_correctness,
- 'response_time': api_result['response_time'],
- 'success': True
- }
-
- async def run_evaluation(self, val_data_path: str = "val_data.json") -> Dict[str, Any]:
- """
- 运行完整评估
- """
- logger.info("开始RAG系统评估")
-
- # 加载验证数据
- val_data = self.load_validation_data(val_data_path)
- self.stats['total_questions'] = len(val_data)
-
- # 创建HTTP会话
- connector = aiohttp.TCPConnector(limit=10)
- timeout = aiohttp.ClientTimeout(total=self.timeout)
-
- async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
- # 批量处理
- tasks = []
- for item in val_data:
- task = self.evaluate_single_item(session, item)
- tasks.append(task)
-
- # 执行评估
- self.results = await asyncio.gather(*tasks, return_exceptions=True)
-
- # 过滤异常结果
- valid_results = [r for r in self.results if not isinstance(r, Exception)]
-
- # 计算总体指标
- metrics = self.calculate_overall_metrics(valid_results)
-
- logger.info("评估完成")
- return metrics
-
- def calculate_overall_metrics(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
- """
- 计算总体评估指标
- """
- if not results:
- return {}
-
- successful_results = [r for r in results if r.get('success', False)]
-
- if not successful_results:
- return {
- 'total_questions': len(results),
- 'successful_evaluations': 0,
- 'success_rate': 0.0,
- 'avg_answer_relevancy': 0.0,
- 'avg_context_precision': 0.0,
- 'avg_context_recall': 0.0,
- 'avg_answer_correctness': 0.0,
- 'avg_response_time': 0.0
- }
-
- # 计算平均指标
- avg_answer_relevancy = sum(r['answer_relevancy'] for r in successful_results) / len(successful_results)
- avg_context_precision = sum(r['context_precision'] for r in successful_results) / len(successful_results)
- avg_context_recall = sum(r['context_recall'] for r in successful_results) / len(successful_results)
- avg_answer_correctness = sum(r['answer_correctness'] for r in successful_results) / len(successful_results)
- avg_response_time = sum(r['response_time'] for r in successful_results) / len(successful_results)
-
- return {
- 'total_questions': len(results),
- 'successful_evaluations': len(successful_results),
- 'success_rate': len(successful_results) / len(results),
- 'avg_answer_relevancy': avg_answer_relevancy,
- 'avg_context_precision': avg_context_precision,
- 'avg_context_recall': avg_context_recall,
- 'avg_answer_correctness': avg_answer_correctness,
- 'avg_response_time': avg_response_time,
- 'detailed_results': successful_results
- }
-
- def save_results(self, metrics: Dict[str, Any], output_file: str = None):
- """
- 保存评估结果
- """
- if output_file is None:
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
- output_file = f"rag_evaluation_results_{timestamp}.json"
-
- try:
- with open(output_file, 'w', encoding='utf-8') as f:
- json.dump(metrics, f, ensure_ascii=False, indent=2)
- logger.info(f"评估结果已保存到: {output_file}")
- except Exception as e:
- logger.error(f"保存结果失败: {e}")
-
- def print_summary(self, metrics: Dict[str, Any]):
- """
- 打印评估摘要
- """
- print("\n" + "="*60)
- print("RAG系统评估结果摘要")
- print("="*60)
- print(f"总问题数: {metrics.get('total_questions', 0)}")
- print(f"成功评估数: {metrics.get('successful_evaluations', 0)}")
- print(f"成功率: {metrics.get('success_rate', 0):.2%}")
- print("\n核心指标:")
- print(f" 答案相关性: {metrics.get('avg_answer_relevancy', 0):.3f}")
- print(f" 上下文精确度: {metrics.get('avg_context_precision', 0):.3f}")
- print(f" 上下文召回率: {metrics.get('avg_context_recall', 0):.3f}")
- print(f" 答案正确性: {metrics.get('avg_answer_correctness', 0):.3f}")
- print(f"\n平均响应时间: {metrics.get('avg_response_time', 0):.2f}秒")
- print("="*60)
- async def main():
- """
- 主函数
- """
- evaluator = SimpleRAGEvaluator()
-
- try:
- # 运行评估
- metrics = await evaluator.run_evaluation()
-
- # 打印摘要
- evaluator.print_summary(metrics)
-
- # 保存结果
- evaluator.save_results(metrics)
-
- except Exception as e:
- logger.error(f"评估过程中发生错误: {e}")
- raise
- if __name__ == "__main__":
- asyncio.run(main())
|