# test_new_new.py (9.4 KB)
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Tuple

import aiohttp
  8. class RAGEvaluator:
  9. def __init__(self, rag_url: str = "http://localhost:6000/rag/chat",
  10. eval_url: str = "https://api.deepseek.com/chat/completions",
  11. eval_key: str = "sk-72f9d0e5bc894e1d828d73bdcc50ff0a"):
  12. self.rag_url = rag_url
  13. self.eval_url = eval_url
  14. self.eval_key = eval_key
  15. self.evaluation_logs = []
  16. # RAG请求模板
  17. self.rag_template = {
  18. # "appId": "2924812721300312064",
  19. "knowledgeIds": ["a2963496869283893248", "a2963501316240183296"],
  20. "knowledgeInfo": '{"param_desc":"strict","show_recall_result":true,"recall_method":"embedding","rerank_status":true,"rerank_model_name":"rerank","slice_config_type":"customized","rerank_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}],"recall_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}]}',
  21. "maxToken": 8192,
  22. "model": "Qwen3-30B",
  23. "temperature": "0.01",
  24. "topP": "0.5",
  25. "embeddingId": "multilingual-e5-large-instruct",
  26. "prompt": "你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{{知识}}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:\"进入登录页面(图1.1)\"需要从文字中删除图序,回复效果为:\"进入登录页面\";\"如图所示1.1\",回复效果为:\"如图所示\";\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:\"我想\",\"我认为\",\"think\"等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{{用户}}"
  27. }
  28. self.eval_threshold = 0.9 # 是否准确的阈值
  29. def load_test_data(self, file_path: str = "val_data.json") -> List[Dict]:
  30. with open(file_path, 'r', encoding='utf-8') as f:
  31. return json.load(f)
  32. async def query_rag(self, session: aiohttp.ClientSession, question: str) -> str:
  33. payload = self.rag_template.copy()
  34. payload["query"] = question
  35. async with session.post(
  36. self.rag_url,
  37. json=payload,
  38. timeout=aiohttp.ClientTimeout(total=30),
  39. headers={'Accept': 'text/event-stream'}
  40. ) as response:
  41. if response.status != 200:
  42. return ""
  43. answer = ""
  44. async for line in response.content:
  45. line_str = line.decode('utf-8').strip()
  46. if not line_str.startswith('data: '):
  47. continue
  48. data_str = line_str[6:]
  49. if data_str == '[DONE]':
  50. break
  51. try:
  52. event_data = json.loads(data_str)
  53. if isinstance(event_data, dict):
  54. if event_data.get('event') == 'add':
  55. answer = event_data.get('data', '')
  56. elif event_data.get('event') == 'finish':
  57. break
  58. elif 'content' in event_data:
  59. answer = event_data['content']
  60. except json.JSONDecodeError:
  61. if data_str != '[DONE]':
  62. answer = data_str
  63. return answer
  64. async def evaluate_answer(self, session: aiohttp.ClientSession, actual: str, expected: str) -> float:
  65. prompt = f"""评估答案相似度,返回0-1分数。
  66. 标准答案:{expected}
  67. 实际答案:{actual}
  68. 评估标准:内容准确性、信息完整性、语义相似性
  69. 只返回数字分数,如:0.85"""
  70. headers = {
  71. "Content-Type": "application/json",
  72. "Authorization": f"Bearer {self.eval_key}"
  73. }
  74. payload = {
  75. "model": "deepseek-chat",
  76. "messages": [{"role": "user", "content": prompt}],
  77. "temperature": 0.1,
  78. "max_tokens": 50
  79. }
  80. async with session.post(self.eval_url, json=payload, headers=headers) as response:
  81. result = await response.json()
  82. score_text = result['choices'][0]['message']['content'].strip()
  83. numbers = re.findall(r'\d+\.\d+|\d+', score_text)
  84. if numbers:
  85. score = float(numbers[0])
  86. return min(max(score, 0.0), 1.0)
  87. return 0.0
  88. async def evaluate_single(self, session: aiohttp.ClientSession, item: Dict,
  89. semaphore: asyncio.Semaphore) -> Tuple[bool, float]:
  90. async with semaphore:
  91. question = item['question']
  92. expected = item['answer']
  93. actual = await self.query_rag(session, question)
  94. score = await self.evaluate_answer(session, actual, expected)
  95. is_correct = score > self.eval_threshold
  96. log_entry = {
  97. "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
  98. "question": question,
  99. "expected_answer": expected,
  100. "actual_answer": actual,
  101. "score": round(score, 2),
  102. "is_correct": is_correct
  103. }
  104. self.evaluation_logs.append(log_entry)
  105. print(f"问题: {question}")
  106. print(f"期望: {expected}")
  107. print(f"实际: {actual}")
  108. print(f"得分: {score:.2f}")
  109. print("-" * 50)
  110. return is_correct, score
  111. async def run_evaluation(self, data_file: str = "val_data.json",
  112. max_concurrent: int = 3) -> float:
  113. data = self.load_test_data(data_file)
  114. print(f"开始评估 {len(data)} 个问题...")
  115. semaphore = asyncio.Semaphore(max_concurrent)
  116. connector = aiohttp.TCPConnector(limit=10, limit_per_host=5)
  117. async with aiohttp.ClientSession(connector=connector) as session:
  118. tasks = [self.evaluate_single(session, item, semaphore) for item in data]
  119. results = await asyncio.gather(*tasks, return_exceptions=True)
  120. correct_count = 0
  121. total_score = 0.0
  122. valid_count = 0
  123. for result in results:
  124. if isinstance(result, Exception):
  125. print(f"评估异常: {result}")
  126. continue
  127. is_correct, score = result
  128. valid_count += 1
  129. total_score += score
  130. if is_correct:
  131. correct_count += 1
  132. if valid_count == 0:
  133. return 0.0
  134. accuracy = correct_count / valid_count
  135. avg_score = total_score / valid_count
  136. print(f"\n=== 评估结果 ===")
  137. print(f"正确率: {correct_count}/{valid_count} = {accuracy:.2%}")
  138. print(f"平均分: {avg_score:.2f}")
  139. self.save_evaluation_results(accuracy, avg_score, correct_count, valid_count)
  140. return accuracy
  141. def save_evaluation_results(self, accuracy: float, avg_score: float,
  142. correct_count: int, valid_count: int):
  143. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  144. filename = f"rag_evaluation_{timestamp}.txt"
  145. os.makedirs("logs", exist_ok=True)
  146. filepath = os.path.join("logs", filename)
  147. with open(filepath, 'w', encoding='utf-8') as f:
  148. f.write("=" * 60 + "\n")
  149. f.write("RAG系统评估结果报告\n")
  150. f.write("=" * 60 + "\n")
  151. f.write(f"评估时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
  152. f.write(f"总问题数: {valid_count}\n")
  153. f.write(f"正确回答数: {correct_count}\n")
  154. f.write(f"正确率: {accuracy:.2%}\n")
  155. f.write(f"平均得分: {avg_score:.2f}\n")
  156. f.write(f"准确阈值: {self.eval_threshold}\n")
  157. f.write("=" * 60 + "\n\n")
  158. f.write("详细评估日志:\n")
  159. f.write("-" * 60 + "\n")
  160. for i, log in enumerate(self.evaluation_logs, 1):
  161. f.write(f"\n[{i}] 评估时间: {log['timestamp']}\n")
  162. f.write(f"问题: {log['question']}\n")
  163. f.write(f"期望答案: {log['expected_answer']}\n")
  164. f.write(f"实际答案: {log['actual_answer']}\n")
  165. f.write(f"评分: {log['score']}\n")
  166. f.write("-" * 60 + "\n")
  167. print(f"\n评估结果已保存到: {filepath}")
  168. async def main():
  169. evaluator = RAGEvaluator()
  170. await evaluator.run_evaluation()
# Script entry point: start the evaluation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())