test_new.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 简化的RAG评估脚本
  5. """
  6. import json
  7. import aiohttp
  8. import asyncio
  9. import time
  10. from typing import List, Dict
  11. class SimpleRAGEvaluator:
  12. def __init__(self):
  13. self.api_url = "http://localhost:6000/rag/chat"
  14. self.request_data = {
  15. "appId": "2924812721300312064",
  16. "desc": "高井信息员工手册,出差管理制度等",
  17. "isDeepThink": "N",
  18. "knowledgeIds": [
  19. "a2963496869283893248",
  20. "a2963501316240183296"
  21. ],
  22. "knowledgeInfo": '{"param_desc":"strict","show_recall_result":true,"recall_method":"embedding","rerank_status":true,"rerank_model_name":"rerank","slice_config_type":"customized","rerank_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}],"recall_index_type_list":[{"index_type_id":0,"knowledge_id":["a2963496869283893248","a2963501316240183296"]}]}',
  23. "maxToken": 8192,
  24. "model": "Qwen3-30B",
  25. "name": "高井信息公司管理制度",
  26. "params": {},
  27. "prompt": "你是一位知识检索助手,你必须并且只能从我发送的众多知识片段中寻找能够解决用户输入问题的最优答案,并且在执行任务的过程中严格执行规定的要求。\n\n知识片段如下:\n{{知识}}\n\n规定要求:\n- 找到答案就仅使用知识片段中的原文回答用户的提问;\n- 找不到答案就用自身知识并且告诉用户该信息不是来自文档;\n- 所引用的文本片段中所包含的示意图占位符必须进行返回,占位符格式参考:【示意图序号_编号】\n - 严禁输出任何知识片段中不存在的示意图占位符;\n - 输出的内容必须删除其中包含的任何图注、序号等信息。例如:\"进入登录页面(图1.1)\"需要从文字中删除图序,回复效果为:\"进入登录页面\";\"如图所示1.1\",回复效果为:\"如图所示\";\n- 格式规范\n - 文档中会出现包含表格的情况,表格是以图片标识符的形式呈现,表格中缺失数据时候返回空单元格;\n - 如果需要用到表格中的数据,以markdown格式输出表格中的数据;\n - 避免使用代码块语法回复信息;\n - 回复的开头语不要输出诸如:\"我想\",\"我认为\",\"think\"等相关语义的文本。\n\n严格执行规定要求,不要复述问题,直接开始回答。\n\n用户输入问题:\n{{用户}}",
  28. "status": "3",
  29. "temperature": "0.01",
  30. "topP": "0.5",
  31. "typeId": 40,
  32. "updateBy": "1",
  33. "visible": "0",
  34. "embeddingId": "multilingual-e5-large-instruct",
  35. "enable_think": False
  36. }
  37. def load_data(self, file_path: str = "val_data.json") -> List[Dict]:
  38. with open(file_path, 'r', encoding='utf-8') as f:
  39. return json.load(f)
  40. async def call_api(self, session: aiohttp.ClientSession, question: str) -> str:
  41. data = self.request_data.copy()
  42. data["query"] = question
  43. try:
  44. async with session.post(
  45. self.api_url,
  46. json=data,
  47. timeout=aiohttp.ClientTimeout(total=30),
  48. headers={'Accept': 'text/event-stream'}
  49. ) as response:
  50. if response.status == 200:
  51. full_answer = ""
  52. # 处理流式响应
  53. async for line in response.content:
  54. line_str = line.decode('utf-8').strip()
  55. if not line_str or line_str.startswith(':'):
  56. continue
  57. if line_str.startswith('data: '):
  58. data_str = line_str[6:] # 移除 'data: ' 前缀
  59. if data_str == '[DONE]':
  60. break
  61. # 尝试解析JSON格式
  62. try:
  63. event_data = json.loads(data_str)
  64. # 确保event_data是字典类型才进行处理
  65. if isinstance(event_data, dict):
  66. if event_data.get('event') == 'add':
  67. # 直接替换为最新的完整内容(而不是累积)
  68. chunk_data = event_data.get('data', '')
  69. if isinstance(chunk_data, str):
  70. full_answer = chunk_data # 直接替换,因为每个chunk包含完整内容
  71. elif event_data.get('event') == 'finish':
  72. break
  73. # 只有当没有event字段时才检查content字段
  74. elif 'content' in event_data and 'event' not in event_data:
  75. content_data = event_data['content']
  76. if isinstance(content_data, str):
  77. full_answer = content_data # 直接替换
  78. except json.JSONDecodeError:
  79. # 如果不是JSON格式,直接作为文本内容处理
  80. if data_str and data_str != '[DONE]':
  81. full_answer = data_str # 直接替换
  82. continue
  83. except Exception:
  84. continue
  85. return full_answer
  86. return ''
  87. except Exception as e:
  88. print(f"API调用失败: {e}")
  89. return ''
  90. async def call_openai_compatible_api(self, session: aiohttp.ClientSession, prompt: str) -> str:
  91. """调用 OpenAI 兼容 API 进行评估"""
  92. # OpenAI 兼容的 API 配置
  93. api_url = "https://api.deepseek.com/chat/completions"
  94. api_key = "sk-72f9d0e5bc894e1d828d73bdcc50ff0a"
  95. model_name = "deepseek-chat"
  96. headers = {
  97. "Content-Type": "application/json",
  98. "Authorization": f"Bearer {api_key}"
  99. }
  100. data = {
  101. "model": model_name,
  102. "messages": [{"role": "user", "content": prompt}],
  103. "temperature": 0.1,
  104. "max_tokens": 100,
  105. "stream": False
  106. }
  107. async with session.post(api_url, json=data, headers=headers, timeout=aiohttp.ClientTimeout(total=30)) as response:
  108. result = await response.json()
  109. return result['choices'][0]['message']['content'].strip()
  110. async def check_accuracy(self, session: aiohttp.ClientSession, actual: str, expected: str) -> float:
  111. prompt = f"""请评估以下两个答案的相似度和准确性,给出0到1之间的分数(保留2位小数)。
  112. 标准答案:{expected}
  113. 实际答案:{actual}
  114. 评估标准:
  115. 1. 内容准确性:实际答案是否正确回答了问题
  116. 2. 信息完整性:实际答案是否包含了标准答案的关键信息
  117. 3. 语义相似性:两个答案在语义上的相似程度
  118. 请只返回一个0到1之间的数字分数,不要包含其他文字说明。例如:0.85"""
  119. score_str = await self.call_openai_compatible_api(session, prompt)
  120. import re
  121. numbers = re.findall(r'\d+\.\d+|\d+', score_str)
  122. score = float(numbers[0])
  123. return min(max(score, 0.0), 1.0)
  124. async def evaluate_single_item(self, session: aiohttp.ClientSession, item: Dict, semaphore: asyncio.Semaphore) -> tuple:
  125. """评估单个问题项"""
  126. async with semaphore: # 限制并发数量
  127. question = item['question']
  128. expected = item['answer']
  129. # 并发调用API和评估
  130. actual = await self.call_api(session, question)
  131. accuracy_score = await self.check_accuracy(session, actual, expected)
  132. is_correct = accuracy_score > 0.5 # 阈值设为0.5
  133. print(f"问题: {question}")
  134. print(f"期望: {expected}")
  135. print(f"实际: {actual}")
  136. print(f"准确率: {accuracy_score:.2%}")
  137. print("-" * 50)
  138. return is_correct, accuracy_score
  139. async def evaluate(self, data_file: str = "val_data.json", max_concurrent: int = 5):
  140. """异步并发评估"""
  141. data = self.load_data(data_file)
  142. total = len(data)
  143. # 创建信号量来限制并发数量
  144. semaphore = asyncio.Semaphore(max_concurrent)
  145. # 创建aiohttp会话
  146. connector = aiohttp.TCPConnector(limit=20, limit_per_host=10)
  147. async with aiohttp.ClientSession(connector=connector) as session:
  148. # 创建所有任务
  149. tasks = [self.evaluate_single_item(session, item, semaphore) for item in data]
  150. # 并发执行所有任务
  151. results = await asyncio.gather(*tasks, return_exceptions=True)
  152. # 统计结果
  153. correct = 0
  154. valid_results = 0
  155. for result in results:
  156. if isinstance(result, Exception):
  157. print(f"评估出错: {result}")
  158. continue
  159. is_correct, accuracy_score = result
  160. valid_results += 1
  161. if is_correct:
  162. correct += 1
  163. accuracy = correct / valid_results if valid_results > 0 else 0
  164. print(f"\n总体准确率: {correct}/{valid_results} = {accuracy:.2%}")
  165. return accuracy
  166. async def main():
  167. evaluator = SimpleRAGEvaluator()
  168. await evaluator.evaluate()
  169. if __name__ == "__main__":
  170. asyncio.run(main())