chat_message.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691
from rag.db import MilvusOperate, MysqlOperate
from rag.load_model import *
from rag.llm import VllmApi
from config import redis_config_dict
import json
import re
import gc
import redis
from utils.get_logger import setup_logger
import random
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import httpx
# Sensitive-word / stop-word detection helpers
from rag.sensitive_words_detection import TxtSensitiveWordStore, SecurityTokenizer, build_warning
# LightRAG knowledge-graph retrieval
from lightRAG import AsyncLightRAGManager

# Module-level sensitive-word store and security tokenizer used for input screening.
store = TxtSensitiveWordStore("./rag/sensitive_words/")
tokenizer_sensitive = SecurityTokenizer(store)
# from rag.load_model import tokenizer
logger = setup_logger(__name__)

# In-process rerank models (disabled; reranking now goes through vLLM HTTP services).
# rerank_model_mapping = {
# "bge-reranker-v2-m3": (bce_rerank_tokenizer, bce_rerank_base_model),
# "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
# "Qwen3-Reranker-0.6B": (qwen_rerank_tokenizer, qwen_rerank_base_model)
# }

# Maps rerank model name -> (service base URL, model identifier) for the vLLM
# rerank endpoints. URL/model names presumably come from the rag.load_model
# star import — TODO confirm.
rerank_model_mapping_vllm = {
    "bge-reranker-v2-m3": (rerank_bge_url, rerank_bge_model),
    "Qwen3-Reranker-0.6B": (rerank_qwen_url, rerank_qwen_model)
}

# Redis connection parameters pulled from the shared config dict.
redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")
class ChatRetrieverRag:
    """RAG chat pipeline: Milvus retrieval, rerank, graph/chapter expansion, LLM answering."""

    def __init__(self, chat_json: dict = None, chat_id: str = None):
        # chat_id identifies an existing conversation (Redis-backed state).
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # NOTE(review): the vLLM client is only created when no chat_id is
        # supplied, yet other methods use self.vllm_client unconditionally —
        # confirm this branch is intended.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)
  41. def count_tokens(self,text):
  42. tokens = tokenizer.tokenize(text)
  43. return len(tokens)
  44. # 手动rerank
  45. # def rerank_result(self,query, top_k, hybrid_search_result, rerank_embedding_name):
  46. # rerank_list = []
  47. # for result in hybrid_search_result:
  48. # rerank_list.append([query, result["content"]])
  49. # tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
  50. # # 重排序
  51. # rerank_model.eval()
  52. # inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
  53. # inputs = {k: v.to(device) for k, v in inputs.items()}
  54. # with torch.no_grad():
  55. # logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
  56. # logits = logits.detach().cpu()
  57. # scores = torch.sigmoid(logits).tolist()
  58. # # scores_list = scores.tolist()
  59. # logger.info(f"重排序的得分:{scores}")
  60. # # sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
  61. # threshold = 0.3
  62. # sorted_pairs = sorted( ((s, r) for s, r in zip(scores, hybrid_search_result) if s >= threshold), key=lambda x: x[0], reverse=True )
  63. # if not sorted_pairs:
  64. # sorted_scores = []
  65. # sorted_search = []
  66. # else:
  67. # sorted_scores, sorted_search = zip(*sorted_pairs)
  68. # logger.info(f"过滤后的{sorted_scores}")
  69. # # sorted_scores, sorted_search = zip(*sorted_pairs)
  70. # sorted_scores_list = list(sorted_scores)
  71. # sorted_search_list = list(sorted_search)
  72. # for score, search in zip(sorted_scores_list, sorted_search_list):
  73. # search["rerank_score"] = score
  74. # del inputs, logits
  75. # gc.collect()
  76. # torch.cuda.empty_cache()
  77. # search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list)))
  78. # return search_result
  79. # rerank(url)
  80. async def rerank_result_async(self, query, top_k, hybrid_search_result, rerank_embedding_name):
  81. """
  82. Args:
  83. query (str): 查询文本
  84. top_k (int): 最终返回结果数量
  85. hybrid_search_result (list[dict]): 原始搜索结果列表
  86. rerank_embedding_name (str): rerank 模型
  87. Returns:
  88. list[dict]: 添加了 "rerank_score" 字段的结果
  89. """
  90. # 查模型信息
  91. rerank_url, rerank_model = rerank_model_mapping_vllm.get(rerank_embedding_name)
  92. if not hybrid_search_result:
  93. return []
  94. # 构建候选文档列表
  95. documents = [r["content"] for r in hybrid_search_result]
  96. payload = {
  97. "model": rerank_model,
  98. "query": query,
  99. "documents": documents
  100. }
  101. service_url = f"{rerank_url}/v1/rerank"
  102. # 使用异步 HTTP 请求
  103. try:
  104. async with httpx.AsyncClient(timeout=60.0) as client:
  105. resp = await client.post(service_url, json=payload)
  106. resp.raise_for_status()
  107. result = resp.json()
  108. except Exception as e:
  109. logger.error(f"调用 rerank 服务失败: {e}")
  110. # 出错 fallback
  111. scores = [0.0] * len(documents)
  112. else:
  113. # 解析分数
  114. if "results" in result:
  115. scores = [item.get("relevance_score", 0.0) for item in result["results"]]
  116. elif "data" in result:
  117. scores = [item.get("score", 0.0) for item in result["data"]]
  118. else:
  119. logger.warning(f"未知 rerank 返回格式: {result}")
  120. scores = [0.0] * len(documents)
  121. logger.info(f"重排序的得分:{scores}")
  122. # 构建分数与搜索结果列表
  123. sorted_pairs = list(zip(scores, hybrid_search_result))
  124. # # 阈值过滤
  125. # threshold = 0.2
  126. # filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold]
  127. # if not filtered_pairs:
  128. # filtered_pairs = sorted_pairs
  129. # logger.info(f"过滤后的{filtered_pairs}")
  130. # # 按分数降序排序
  131. # filtered_pairs.sort(key=lambda x: x[0], reverse=True)
  132. threshold = 0.3
  133. # sorted_pairs = sorted(
  134. # ((s, r) for s, r in sorted_pairs if s >= threshold) or sorted_pairs,
  135. # key=lambda x: x[0],
  136. # reverse=True
  137. # )
  138. filtered_pairs = [(s, r) for s, r in sorted_pairs if s >= threshold]
  139. if not filtered_pairs:
  140. filtered_pairs = sorted_pairs
  141. sorted_pairs = sorted(
  142. filtered_pairs,
  143. key=lambda x: x[0],
  144. reverse=True
  145. )
  146. if not sorted_pairs:
  147. sorted_scores_list = []
  148. sorted_search_list = []
  149. else:
  150. sorted_scores_list, sorted_search_list = zip(*sorted_pairs)
  151. logger.info(f"过滤后的{sorted_scores_list}")
  152. # 写入 rerank_score
  153. # sorted_scores_list, sorted_search_list = zip(*filtered_pairs)
  154. sorted_scores_list, sorted_search_list = list(sorted_scores_list), list(sorted_search_list)
  155. for score, search in zip(sorted_scores_list, sorted_search_list):
  156. search["rerank_score"] = score
  157. # 清理 GPU 内存
  158. gc.collect()
  159. torch.cuda.empty_cache()
  160. # 随机采样 top_k
  161. search_result = random.sample(sorted_search_list, min(top_k, len(sorted_search_list)))
  162. # logger.info(f"query='{query}' rerank结果: {[r['rerank_score'] for r in search_result]}")
  163. return search_result
  164. async def segment_query_with_llm(self, query, max_retries=3):
  165. """
  166. 使用LLM对长问题进行切分,提取多个子问题
  167. """
  168. segment_prompt = f"""
  169. 你是一个专业的问题分析助手。请将下面的复杂问题分解为3-5个具体的子问题,每个子问题都应该是独立的、可检索的。
  170. 原问题:{query}
  171. 请按照以下JSON格式返回:
  172. {{
  173. "sub_questions": [
  174. "子问题1",
  175. "子问题2",
  176. "子问题3"
  177. ]
  178. }}
  179. 要求:
  180. 1. 每个子问题都要具体明确
  181. 2. 子问题之间相互独立
  182. 3. 覆盖原问题的主要方面
  183. 4. 适合向量检索
  184. """
  185. for attempt in range(max_retries):
  186. try:
  187. logger.info(f"第{attempt + 1}次尝试问题切分")
  188. # 构造请求参数
  189. messages = [{"role": "user", "content": segment_prompt}]
  190. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  191. # 调用LLM
  192. response_text = ""
  193. async for chunk in self.vllm_client.chat(
  194. prompt=segment_prompt,
  195. model=model,
  196. temperature=0.1,
  197. top_p=0.9,
  198. max_tokens=1024,
  199. stream=True,
  200. history=[]
  201. ):
  202. # print(chunk)
  203. if chunk.get("event") == "add":
  204. response_text += chunk.get("data", "")
  205. logger.info(f"LLM返回的切分结果:{response_text}")
  206. # 解析JSON响应
  207. import json
  208. if "```json" in response_text:
  209. json_pattern = r'```json\s*(.*?)```'
  210. matches = re.findall(json_pattern, response_text, re.DOTALL)
  211. if matches:
  212. response_text = matches[0]
  213. result = json.loads(response_text.strip())
  214. sub_questions = result.get("sub_questions", [])
  215. if sub_questions and len(sub_questions) > 0:
  216. logger.info(f"成功切分为{len(sub_questions)}个子问题:{sub_questions}")
  217. return sub_questions
  218. else:
  219. logger.warning("LLM返回的子问题列表为空")
  220. except Exception as e:
  221. logger.error(f"第{attempt + 1}次问题切分失败:{str(e)}")
  222. if attempt == max_retries - 1:
  223. logger.error("问题切分达到最大重试次数,使用原问题")
  224. return [query]
  225. return [query]
  226. def request_milvus_multi(self, collection_name_list, querys, rag_embedding_name, top, mode):
  227. """
  228. 支持多query的向量检索方法
  229. """
  230. results = []
  231. doc_id_to_collection = {}
  232. def query_collection(collection_name, query):
  233. return collection_name, query, MilvusOperate(
  234. collection_name=collection_name,
  235. embedding_name=rag_embedding_name
  236. )._search(query, k=top, mode=mode)
  237. tasks = [(c, q) for c in collection_name_list for q in querys]
  238. max_workers = min(20, len(tasks))
  239. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  240. future_to_task = {
  241. executor.submit(query_collection, c, q): (c, q) for c, q in tasks
  242. }
  243. for future in as_completed(future_to_task):
  244. collection_name, query = future_to_task[future]
  245. try:
  246. _, _, search_result = future.result()
  247. for doc in search_result:
  248. doc_id = doc["doc_id"]
  249. if doc_id not in doc_id_to_collection:
  250. doc_id_to_collection[doc_id] = collection_name
  251. results.extend(search_result)
  252. except Exception as e:
  253. logger.error(f"查询集合 {collection_name} 对 query {query} 出错: {e}")
  254. return results, doc_id_to_collection
  255. def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode):
  256. search_multi_collection_list = []
  257. search_doc_id_to_knowledge_id_dict = {}
  258. # 多线程查询
  259. def query_collection(collection_name):
  260. return collection_name, MilvusOperate(
  261. collection_name=collection_name,
  262. embedding_name=rag_embedding_name
  263. )._search(query, k=top, mode=mode)
  264. max_workers = min(10, len(collection_name_list))
  265. with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
  266. future_to_name = {
  267. executor.submit(query_collection, name): name
  268. for name in collection_name_list
  269. }
  270. for future in concurrent.futures.as_completed(future_to_name):
  271. collection_name = future_to_name[future]
  272. try:
  273. _, hybrid_search_result = future.result()
  274. for result in hybrid_search_result:
  275. doc_id = result.get("doc_id")
  276. if doc_id not in search_doc_id_to_knowledge_id_dict:
  277. search_doc_id_to_knowledge_id_dict[doc_id] = collection_name
  278. search_multi_collection_list.extend(hybrid_search_result)
  279. except Exception as e:
  280. logger.error(f"查询集合 {collection_name} 时出错: {str(e)}")
  281. continue
  282. return search_multi_collection_list, search_doc_id_to_knowledge_id_dict
    async def retriever_result(self, chat_json, rewrite = None):
        """
        End-to-end retrieval pipeline for one chat request.

        Pipeline: optional LLM question segmentation -> multi-collection Milvus
        retrieval -> optional remote rerank -> LightRAG knowledge-graph
        expansion -> parent/child chapter expansion.

        Args:
            chat_json (dict): request payload; reads knowledgeIds, sliceCount,
                knowledgeInfo (JSON string), embeddingId, query,
                Problem_segmentation.
            rewrite (str | None): rewritten query; when set it replaces
                chat_json["query"].

        Returns:
            tuple[list[dict], dict]: (results, each carrying "rerank_score";
            mapping of doc_id -> collection/knowledge id)
        """
        # Map the configured recall method onto the Milvus search mode.
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name_list = chat_json.get("knowledgeIds")
        top = chat_json.get("sliceCount", 15)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        # A rewritten query (if supplied) replaces the original.
        if rewrite:
            chat_json["query"] = rewrite
        query = chat_json.get("query")
        query_list = [query]
        # Old sequential per-collection retrieval (superseded by request_milvus_multi):
        # search_multi_collection_list = []
        # search_doc_id_to_knowledge_id_dict = {}
        # for collection_name in collection_name_list:
        #     hybrid_search_result = MilvusOperate(
        #         collection_name=collection_name,embedding_name=rag_embedding_name)._search(query, k=top, mode=mode)
        #     for result in hybrid_search_result:
        #         doc_id = result.get("doc_id")
        #         if doc_id not in search_doc_id_to_knowledge_id_dict:
        #             search_doc_id_to_knowledge_id_dict[doc_id] = collection_name
        #     search_multi_collection_list.extend(hybrid_search_result)
        """
        search_multi_collection_list信息格式:
        {
        "doc_id": "07917ce4-ce63-11f0-ad4f-ecd68ae97d92",
        "chunk_id": "7a4aba3e-ce63-11f0-8a8d-ecd68ae97d92",
        "content": "## 3.13 高处作业 24\n",
        "Father_Chapter": "3->3.6",
        "Chapter": "3->3.6->3.6.2",
        "metadata": {
        "source": "超长ceshi.pdf",
        "chunk_index": 14,
        "chunk_len": 16
        },
        "score": 0.016393441706895828,
        "rerank_score": 1
        }
        """
        start_time = time.time()
        # two_time < 0 means the parent/child chapter expansion did not run.
        two_time = -1
        # if rewrite:
        #     query = rewrite
        #     query_list = [rewrite]
        # 1. Check whether LLM question segmentation is enabled.
        problem_segmentation = chat_json.get("Problem_segmentation", False)
        if problem_segmentation:
            logger.info("启用问题切分模式")
            # Split the question with the LLM; falls back to [query] on failure.
            query_list = await self.segment_query_with_llm(query)
            logger.info(f"切分后的问题列表:{query_list}")
        # [query + "\n" + summary if summary else query for query in query_list]
        # Multi-round retrieval
        # if not query_list:
        #     query_list = [query + "\n" + summary] if summary else [query]
        logger.info(f"query_list:{query_list}")
        search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus_multi(
            collection_name_list, query_list, rag_embedding_name, top, mode
        )
        # Single-shot retrieval (superseded):
        # query_list = query
        # search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus(
        #     collection_name_list, rag_embedding_name, query, top, mode
        # )
        logger.info(f"根据{collection_name_list}检索到结果:{search_multi_collection_list}")
        if len(search_multi_collection_list) <= 0:
            return [], search_doc_id_to_knowledge_id_dict
        # 2. Rerank via the remote service, or assign a neutral score of 1.
        if rerank_status:
            # rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name)
            rerank_result_list = await self.rerank_result_async(query, top, search_multi_collection_list, rerank_embedding_name)
        else:
            for result in search_multi_collection_list:
                result["rerank_score"] = 1
            rerank_result_list = random.sample(search_multi_collection_list, min(top, len(search_multi_collection_list)))
        # 3. Collect doc ids; MySQL decides which docs get chapter expansion.
        doc_ids = list({r["doc_id"] for r in rerank_result_list})
        # Collect knowledge-base ids for knowledge-graph retrieval.
        knowledge_ids = list({r["metadata"].get("knowledge_id", "") for r in rerank_result_list})
        logger.info(f"检索到的知识库id{knowledge_ids}")
        mysql_client = MysqlOperate()
        enabled_doc_ids = mysql_client.query_parent_generation_enabled(doc_ids)
        enabled_kn_rs = mysql_client.query_knowledge_by_ids(knowledge_ids)
        enabled_kn_gp_ids = set()
        for r in enabled_kn_rs:
            if r["knowledge_graph"]:
                enabled_kn_gp_ids.add(r["knowledge_id"])
        # Knowledge-graph expansion via LightRAG for enabled knowledge bases.
        if enabled_kn_gp_ids:
            gp_add_list = []
            enabled_kn_gp_ids = list(enabled_kn_gp_ids)
            gp_time = time.time()
            logger.info(f"需要查询知识图谱的知识库:{enabled_kn_gp_ids}")
            # if chat_json.get("lightrag"):
            for kn_id in enabled_kn_gp_ids:
                result = await AsyncLightRAGManager().retrieve(label=kn_id, query=query)
                chunks_add = result["data"].get("chunks", [])
                if chunks_add:
                    # chunk_dict = {}
                    for chunk_add in chunks_add:
                        # chunk_dict["content"] = chunk_add
                        # Graph chunks carry no doc/chunk ids; neutral score 1.
                        gp_add_list.append({
                            "doc_id": None,
                            "chunk_id": None,
                            "content": chunk_add["content"],
                            "metadata": {},
                            "rerank_score": 1,
                            "revision_status": None,
                            "section": None,
                            "parent_section": None,
                        })
            logger.info(f"知识图谱扩展的内容:{gp_add_list}")
            rerank_result_list.extend(gp_add_list)
        # Parent/child chapter expansion for documents that enabled it.
        if enabled_doc_ids:
            tmp_time = time.time()
            logger.info(f"需要章节扩展的文档: {enabled_doc_ids}")
            existing_chunk_ids = {r["chunk_id"] for r in rerank_result_list}
            # Collect (collection_name, doc_id, Father_Chapter) tuples to expand:
            # collection name - document id - parent section heading.
            father_chapter_set = set()
            for result in rerank_result_list:
                doc_id = result.get("doc_id")
                if doc_id not in enabled_doc_ids:
                    continue
                father_chapter = result.get("Father_Chapter")
                collection_name = search_doc_id_to_knowledge_id_dict.get(doc_id)
                if father_chapter and collection_name:
                    father_chapter_set.add((collection_name, doc_id, father_chapter))
            # 4. Second-pass retrieval by chapter; append unseen chunks only.
            for collection_name, doc_id, father_chapter in father_chapter_set:
                chapter_results = MilvusOperate(
                    collection_name=collection_name, embedding_name=rag_embedding_name
                )._query_by_scalar_field(doc_id, "Father_Chapter", father_chapter)
                for r in chapter_results:
                    if r["chunk_id"] not in existing_chunk_ids:
                        r["rerank_score"] = 1
                        rerank_result_list.append(r)
                        existing_chunk_ids.add(r["chunk_id"])
            two_time = time.time() - tmp_time
        logger.info(f"父子章节扩展后的结果数: {len(rerank_result_list)}")
        up_time = time.time() - start_time
        two_status = "未启动父子召回" if two_time < 0 else f"启动父子召回耗时:{two_time}"
        logger.info(f"检索耗时:{up_time}({two_status})")
        return rerank_result_list, search_doc_id_to_knowledge_id_dict
  431. async def generate_rag_response(self, chat_json, chunk_content):
  432. # logger.info(f"rag聊天的请求参数:{chat_json}")
  433. # retriever_result_list = self.retriever_result(chat_json)
  434. # logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
  435. prompt = chat_json.get("prompt")
  436. query = chat_json.get("query")
  437. stream = chat_json.get("stream", True)
  438. # chunk_content = ""
  439. # for retriever in retriever_result_list:
  440. # chunk_content += retriever["content"]
  441. prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
  442. logger.info(f"请求的提示词:{prompt}")
  443. temperature = float(chat_json.get("temperature", 0.6))
  444. top_p = float(chat_json.get("topP", 0.7))
  445. max_token = chat_json.get("maxToken", 4096)
  446. enable_think = chat_json.get("enable_think", False)
  447. history = chat_json.get("messages", [])
  448. if history:
  449. history[-1]["content"] = prompt
  450. # 调用模型获取返回的结果
  451. # model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
  452. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  453. chat_resp = ""
  454. async for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=stream, history=history, enable_think=enable_think):
  455. logger.info(f"chat_message中接收到的LLM信息{chunk}")
  456. if stream:
  457. if chunk.get("event") == "add":
  458. chat_resp += chunk.get("data")
  459. chunk["data"] = chat_resp
  460. yield chunk
  461. else:
  462. yield chunk
  463. else:
  464. chunk
  465. def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict):
  466. doc_info = {}
  467. chunk_content = ""
  468. # 组织成每个doc对应的json格式
  469. for retriever in retriever_result_list:
  470. chunk_text = retriever["content"]
  471. chunk_content += chunk_text
  472. doc_id = retriever["doc_id"]
  473. chunk_id = retriever["chunk_id"]
  474. rerank_score = retriever["rerank_score"]
  475. doc_name = retriever.get("metadata").get("source","")
  476. if doc_id in doc_info:
  477. c_d = {
  478. "chunk_id": chunk_id,
  479. "rerank_score": rerank_score,
  480. "chunk_len": len(chunk_text),
  481. "document_name": retriever.get("document_name"),
  482. "revision_status": retriever.get("revision_status"),
  483. "ref_slice_id": retriever.get("ref_slice_id"),
  484. "revision_group_id": retriever.get("revision_group_id"),
  485. "section": retriever.get("section"),
  486. "parent_section": retriever.get("parent_section"),
  487. }
  488. doc_info[doc_id]["chunk_list"].append(c_d)
  489. # doc_info[doc_id]["rerank_score"].append(rerank_score)
  490. else:
  491. c_d = {
  492. "chunk_id": chunk_id,
  493. "rerank_score": rerank_score,
  494. "chunk_len": len(chunk_text),
  495. "document_name": retriever.get("document_name"),
  496. "revision_status": retriever.get("revision_status"),
  497. "ref_slice_id": retriever.get("ref_slice_id"),
  498. "revision_group_id": retriever.get("revision_group_id"),
  499. "section": retriever.get("section"),
  500. "parent_section": retriever.get("parent_section"),
  501. }
  502. doc_info[doc_id] = {
  503. "doc_name": doc_name,
  504. "chunk_list": [c_d],
  505. }
  506. # doc_list = []
  507. knowledge_info_dict = {}
  508. """
  509. {
  510. "knowledge_id": [{}, {}],
  511. "knowledge_id"
  512. }
  513. """
  514. for k, v in doc_info.items():
  515. d = {}
  516. knowledge_id = search_doc_id_to_knowledge_id_dict.get(k)
  517. d["doc_id"] = k
  518. d["doc_name"] = v.get("doc_name")
  519. d["chunk_nums"] = len(v.get("chunk_list"))
  520. d["chunk_info_list"] = v.get("chunk_list")
  521. # d["chunk_len"] = len(v.get("chunk_id_list"))
  522. if knowledge_id in knowledge_info_dict:
  523. knowledge_info_dict[knowledge_id].append(d)
  524. else:
  525. knowledge_info_dict[knowledge_id] = [d]
  526. # doc_list.append(d)
  527. return chunk_content, knowledge_info_dict
  528. def _inject_revision_rules_to_prompt(self, prompt: str, rules_text: str):
  529. if not rules_text:
  530. return prompt
  531. if not prompt:
  532. return rules_text
  533. # insert_anchor = "用户输入问题"
  534. # if insert_anchor in prompt:
  535. # idx = prompt.find(insert_anchor)
  536. # return prompt[:idx] + rules_text + "\n\n" + prompt[idx:]
  537. # if "{用户}" in prompt:
  538. # return prompt.replace("{用户}", rules_text + "\n\n{用户}", 1)
  539. return prompt + "\n\n" + rules_text
  540. def _normalize_revision_status(self, value):
  541. return value
  542. # if value is None:
  543. # return None
  544. # if isinstance(value, bool):
  545. # return int(value)
  546. # if isinstance(value, int):
  547. # return value
  548. # try:
  549. # s = str(value).strip()
  550. # if s == "":
  551. # return None
  552. # return int(s)
  553. # except Exception:
  554. # return None
  555. def _normalize_slice_id(self, value):
  556. # if value is None:
  557. # return None
  558. # s = str(value).strip()
  559. # if s == "":
  560. # return None
  561. # if s.isdigit() and len(s) < 20:
  562. # return s.zfill(20)
  563. return value
  564. def _format_slice_block(self, doc_name: str, version_tag: str, section: str, parent_section: str, text: str, extra_title_suffix: str = ""):
  565. title = (doc_name or "未知文档")
  566. if version_tag:
  567. title += f"({version_tag})"
  568. if extra_title_suffix:
  569. title += extra_title_suffix
  570. lines = [f"> {title}"]
  571. if parent_section:
  572. lines.append(f"{parent_section}")
  573. if section and section != parent_section:
  574. lines.append(f"{section}" if parent_section else f"{section}")
  575. if text:
  576. lines.append(str(text).strip())
  577. return "\n".join(lines).strip() + "\n\n"
  578. # def _build_llm_chunk_content(self, retriever_result_list):
  579. # slice_map = {}
  580. # for r in retriever_result_list:
  581. # sid = self._normalize_slice_id(r.get("chunk_id"))
  582. # if sid:
  583. # # def _build_llm_chunk_content(self, retriever_result_list):
  584. # # slice_map = {}
  585. # # for r in retriever_result_list:
  586. # # sid = self._normalize_slice_id(r.get("chunk_id"))
  587. # # if sid:
  588. # # slice_map[sid] = r
  589. # original_to_revision_ids = {}
  590. # revision_ids = set()
  591. # for r in retriever_result_list:
  592. # # if self._normalize_revision_status(r.get("revision_status")) == 1 and r.get("ref_slice_id"):
  593. # if r.get("revision_status") == "1" and r.get("ref_slice_id"):
  594. # original_to_revision_ids[r.get("chunk_id")] = r.get("ref_slice_id")
  595. # revision_ids.add(r.get("ref_slice_id"))
  596. # if sid:
  597. # slice_map[sid] = r
  598. # visited = set()
  599. # blocks = []
  600. # for r in retriever_result_list:
  601. # sid = self._normalize_slice_id(r.get("chunk_id"))
  602. # if not sid or sid in visited:
  603. # continue
  604. # revision_status = r.get("revision_status")
  605. # ref_slice_id = self._normalize_slice_id(r.get("ref_slice_id"))
  606. # if revision_status == "1" and ref_slice_id:
  607. # original = r
  608. # revision = slice_map.get(ref_slice_id)
  609. # blocks.append(
  610. # self._format_slice_block(
  611. # original.get("document_name"),
  612. # "原始版",
  613. # original.get("section"),
  614. # original.get("parent_section"),
  615. # original.get("content"),
  616. # extra_title_suffix=" [已修订]",
  617. # )
  618. # )
  619. # visited.add(sid)
  620. # revision_sid = self._normalize_slice_id(revision.get("chunk_id")) if revision else None
  621. # if revision and revision_sid and revision_sid not in visited:
  622. # blocks.append(
  623. # self._format_slice_block(
  624. # revision.get("document_name"),
  625. # "修订版",
  626. # revision.get("section"),
  627. # revision.get("parent_section"),
  628. # revision.get("content"),
  629. # )
  630. # )
  631. # visited.add(revision_sid)
  632. # continue
  633. # if revision_status == "0":
  634. # deprecated_text = (r.get("content") or "")
  635. # deprecated_text = deprecated_text + "\n\n【已废弃】此切片已弃用,请优先参考修订版或最新文档。"
  636. # blocks.append(
  637. # self._format_slice_block(
  638. # r.get("document_name"),
  639. # "已废弃",
  640. # r.get("section"),
  641. # r.get("parent_section"),
  642. # deprecated_text,
  643. # )
  644. # )
  645. # visited.add(sid)
  646. # continue
  647. # if sid in revision_ids:
  648. # continue
  649. # blocks.append(
  650. # self._format_slice_block(
  651. # r.get("document_name"),
  652. # "",
  653. # r.get("section"),
  654. # r.get("parent_section"),
  655. # r.get("content"),
  656. # )
  657. # )
  658. # visited.add(sid)
  659. # return "".join(blocks).strip()
  660. def _build_llm_chunk_content(self, retriever_result_list):
  661. """
  662. 构建 LLM 用的切片内容,仅发送最初检索到的向量数据库内容块,每个块添加 chunk_id。
  663. """
  664. logger.info(f"_build_llm_chunk_content 接收到的切片数量: {len(retriever_result_list)}")
  665. blocks = []
  666. for r in retriever_result_list:
  667. chunk_id = r.get("chunk_id")
  668. content = r.get("content", "")
  669. doc_name = r.get("document_name") or r.get("metadata", {}).get("source", "未知文档")
  670. # 格式:[chunk_id: xxx] 文档名\n内容
  671. block = f"[chunk_id: {chunk_id}]\n> {doc_name}\n{content}\n\n"
  672. blocks.append(block)
  673. return "".join(blocks).strip()
  674. # def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict):
  675. # mysql_client = MysqlOperate()
  676. # slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")]
  677. # slice_ids = [s for s in slice_ids if s]
  678. # slice_rows = []
  679. # success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids)
  680. # if success:
  681. # slice_rows = rows_or_error
  682. # slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")}
  683. # logger.info(slice_rows)
  684. # has_revision_or_deprecated = False
  685. # revision_slice_ids = set()
  686. # original_to_revision_id = {}
  687. # for r in retriever_result_list:
  688. # sid = self._normalize_slice_id(r.get("chunk_id"))
  689. # info = slice_info_map.get(sid)
  690. # if info:
  691. # normalized_revision_status = self._normalize_revision_status(info.get("revision_status"))
  692. # r["revision_status"] = normalized_revision_status
  693. # r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id"))
  694. # r["section"] = info.get("section")
  695. # r["parent_section"] = info.get("parent_section")
  696. # if normalized_revision_status in ("0", "1"):
  697. # has_revision_or_deprecated = True
  698. # logger.info("修订切片")
  699. # else:
  700. # r.setdefault("revision_status", None)
  701. # r.setdefault("ref_slice_id", None)
  702. # r.setdefault("section", None)
  703. # r.setdefault("parent_section", None)
  704. # ref_id = self._normalize_slice_id(r.get("ref_slice_id"))
  705. # sid = self._normalize_slice_id(r.get("chunk_id"))
  706. # if r.get("revision_status") == "1" and ref_id and sid:
  707. # original_to_revision_id[sid] = ref_id
  708. # revision_slice_ids.add(ref_id)
  709. # r["revision_group_id"] = sid
  710. # else:
  711. # r.setdefault("revision_group_id", None)
  712. # ref_slice_rows = []
  713. # if revision_slice_ids:
  714. # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
  715. # if success:
  716. # ref_slice_rows = ref_rows_or_error
  717. # revision_id_to_original_id = {}
  718. # for original_id, revision_id in original_to_revision_id.items():
  719. # if revision_id:
  720. # revision_id_to_original_id[revision_id] = original_id
  721. # for r in retriever_result_list:
  722. # original_id = revision_id_to_original_id.get(r.get("chunk_id"))
  723. # if original_id:
  724. # r["revision_group_id"] = original_id
  725. # r["revision_of_slice_id"] = original_id
  726. # else:
  727. # r.setdefault("revision_of_slice_id", None)
  728. # for row in (slice_rows + ref_slice_rows):
  729. # doc_id = row.get("document_id")
  730. # knowledge_id = row.get("knowledge_id")
  731. # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict:
  732. # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id
  733. # success, ref_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
  734. # if success:
  735. # ref_slice_rows = ref_rows_or_error
  736. # revision_id_to_original_id = {}
  737. # for original_id, revision_id in original_to_revision_id.items():
  738. # if revision_id:
  739. # revision_id_to_original_id[revision_id] = original_id
  740. # for r in retriever_result_list:
  741. # original_id = revision_id_to_original_id.get(r.get("chunk_id"))
  742. # if original_id:
  743. # r["revision_group_id"] = original_id
  744. # r["revision_of_slice_id"] = original_id
  745. # else:
  746. # r.setdefault("revision_of_slice_id", None)
  747. # for row in (slice_rows + ref_slice_rows):
  748. # doc_id = row.get("document_id")
  749. # knowledge_id = row.get("knowledge_id")
  750. # if doc_id and knowledge_id and doc_id not in search_doc_id_to_knowledge_id_dict:
  751. # search_doc_id_to_knowledge_id_dict[doc_id] = knowledge_id
  752. # existing_slice_ids = {r.get("chunk_id") for r in retriever_result_list if r.get("chunk_id")}
  753. # existing_slice_ids = {self._normalize_slice_id(s) for s in existing_slice_ids}
  754. # existing_slice_ids = {s for s in existing_slice_ids if s}
  755. # ref_map = {r.get("slice_id"): r for r in ref_slice_rows if r.get("slice_id")}
  756. # for revision_id in revision_slice_ids:
  757. # if revision_id in existing_slice_ids:
  758. # continue
  759. # info = ref_map.get(revision_id)
  760. # if not info:
  761. # continue
  762. # doc_id = info.get("document_id")
  763. # text = info.get("slice_text") or ""
  764. # metadata = {"source": "", "chunk_index": None, "chunk_len": len(text)}
  765. # original_id = revision_id_to_original_id.get(revision_id)
  766. # retriever_result_list.append(
  767. # {
  768. # "doc_id": doc_id,
  769. # "chunk_id": revision_id,
  770. # "content": text,
  771. # "Father_Chapter": "",
  772. # "Chapter": "",
  773. # "metadata": metadata,
  774. # "score": 0,
  775. # "rerank_score": 1,
  776. # "revision_status": None,
  777. # "ref_slice_id": None,
  778. # "revision_group_id": original_id,
  779. # "revision_of_slice_id": original_id,
  780. # "section": info.get("section"),
  781. # "parent_section": info.get("parent_section"),
  782. # }
  783. # )
  784. # existing_slice_ids.add(revision_id)
  785. # doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")})
  786. # success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids)
  787. # doc_name_map = doc_name_map_or_error if success else {}
  788. # for r in retriever_result_list:
  789. # doc_id = r.get("doc_id")
  790. # doc_name = doc_name_map.get(doc_id)
  791. # r["document_name"] = doc_name
  792. # if r.get("metadata") is None:
  793. # r["metadata"] = {}
  794. # if doc_name:
  795. # r["metadata"]["source"] = doc_name
  796. # rules_text = ""
  797. # if has_revision_or_deprecated:
  798. # rules_text = (
  799. # "切片中可能含有废弃或修订切片。\n"
  800. # "- 若切片标题包含“已废弃”,表示该内容已不再适用;如需引用请明确说明其已废弃,并优先给出最新/修订信息(若有)。\n"
  801. # "- 若切片标题包含“原始版/修订版”,表示同一主题存在版本差异;回答时优先采用修订版,必要时对比说明差异。\n"
  802. # "- 输出时可参考如下结构组织:\n"
  803. # "> 文档A(原始版)[已修订]\n"
  804. # "> 文档B(修订版)\n"
  805. # "> 文档C(已废弃)\n"
  806. # )
  807. # chunk_content_for_llm = self._build_llm_chunk_content(retriever_result_list)
  808. # return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, rules_text
    def _apply_revision_and_doc_name(self, chat_json, retriever_result_list, search_doc_id_to_knowledge_id_dict):
        """
        Resolve slice revision/deprecation state and attach document names.

        Logic:
        1. Query revision metadata for every retrieved slice.
        2. Hand revised content to the LLM (deprecated slices are filtered out
           of the LLM-facing list entirely).
        3. Keep the original slice records intact for later Redis storage.

        Args:
            chat_json: request payload (threaded through, not read here).
            retriever_result_list: retrieved slices; enriched IN PLACE with
                revision_status / ref_slice_id / section / parent_section /
                document_name / metadata["source"].
            search_doc_id_to_knowledge_id_dict: doc_id -> knowledge_id mapping.
                NOTE(review): unlike the commented-out predecessor above, this
                version never adds entries for revision slices' documents —
                confirm downstream citation lookup does not rely on that.

        Returns:
            tuple: (retriever_result_list, search_doc_id_to_knowledge_id_dict,
                    chunk_content_for_llm, chunk_revision_map, llm_result_list)
            where chunk_revision_map maps slice_id -> {"revised_to": id} or
            {"deprecated": True}, and llm_result_list is the filtered /
            substituted slice list actually serialized for the LLM.
        """
        mysql_client = MysqlOperate()
        # Collect all chunk_ids (normalized; falsy ids dropped).
        slice_ids = [self._normalize_slice_id(r.get("chunk_id")) for r in retriever_result_list if r.get("chunk_id")]
        slice_ids = [s for s in slice_ids if s]
        # Query revision info for the retrieved slices.
        slice_rows = []
        if slice_ids:
            success, rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(slice_ids)
            if success:
                slice_rows = rows_or_error
        # Build slice_id -> DB row mapping.
        slice_info_map = {self._normalize_slice_id(r.get("slice_id")): r for r in slice_rows if r.get("slice_id")}
        # Build original->revision mapping plus the revision map for the caller.
        original_to_revision_id = {}
        revision_slice_ids = set()  # ids of all revision slices to fetch below
        chunk_revision_map = {}  # slice_id -> revision/deprecation marker
        for r in retriever_result_list:
            # Current slice id (normalized).
            sid = self._normalize_slice_id(r.get("chunk_id"))
            # DB row for this slice, if any.
            info = slice_info_map.get(sid)
            if info:
                # Normalized revision status ("1" = revised, "0" = deprecated).
                revision_status = self._normalize_revision_status(info.get("revision_status"))
                # Record status, revision pointer and section titles on the slice.
                r["revision_status"] = revision_status
                r["ref_slice_id"] = self._normalize_slice_id(info.get("ref_slice_id"))
                r["section"] = info.get("section")
                r["parent_section"] = info.get("parent_section")
                # Original -> revision mapping.
                ref_id = r.get("ref_slice_id")
                # This slice has been revised and points at its revision.
                if revision_status == "1" and ref_id and sid:
                    # Remember original-id -> revision-id.
                    original_to_revision_id[sid] = ref_id
                    # Queue the revision slice for the follow-up lookup.
                    revision_slice_ids.add(ref_id)
                    # Exposed to the caller for later use.
                    chunk_revision_map[sid] = {"revised_to": ref_id}
                elif revision_status == "0":
                    # Deprecated slice.
                    chunk_revision_map[sid] = {"deprecated": True}
            else:
                # No DB row: make sure the keys exist with None defaults.
                r.setdefault("revision_status", None)
                r.setdefault("ref_slice_id", None)
                r.setdefault("section", None)
                r.setdefault("parent_section", None)
        # Fetch full info for the revision slices.
        revision_slice_info_map = {}
        # Only when at least one slice was actually revised.
        if revision_slice_ids:
            # Look up the revision slices themselves.
            success, revision_rows_or_error = mysql_client.query_slice_revision_info_by_slice_ids_any(list(revision_slice_ids))
            if success:
                # Build revision-id -> row mapping.
                revision_slice_info_map = {
                    self._normalize_slice_id(r.get("slice_id")): r
                    for r in revision_rows_or_error if r.get("slice_id")
                }
                logger.info(f"查询到 {len(revision_slice_info_map)} 个修订切片的完整信息")
        # Attach human-readable document names.
        doc_ids = list({r.get("doc_id") for r in retriever_result_list if r.get("doc_id")})
        success, doc_name_map_or_error = mysql_client.query_document_names_by_document_ids(doc_ids)
        doc_name_map = doc_name_map_or_error if success else {}
        for r in retriever_result_list:
            doc_id = r.get("doc_id")
            doc_name = doc_name_map.get(doc_id)
            r["document_name"] = doc_name
            if r.get("metadata") is None:
                r["metadata"] = {}
            if doc_name:
                r["metadata"]["source"] = doc_name
        # Build the result list for the LLM (revised slices in, deprecated out).
        llm_result_list = []
        deprecated_chunks = []  # deprecated slices, kept for later Redis storage
        for r in retriever_result_list:
            sid = self._normalize_slice_id(r.get("chunk_id"))
            revision_status = r.get("revision_status")
            # Deprecated slices are never shown to the LLM.
            if revision_status == "0":
                deprecated_chunks.append(r)
                logger.info(f"过滤废弃切片 {sid},不传给LLM")
                continue
            # revision_status="1" but ref_slice_id empty: use the inline revised text.
            if revision_status == "1" and not r.get("ref_slice_id"):
                info = slice_info_map.get(sid)
                revision_text = info.get("revision_slice_text", "") if info else ""
                if not revision_text:
                    revision_text = r.get("content", "")
                r["content"] = revision_text
                llm_result_list.append(r)
                logger.info(f"使用切片 {sid} 的 revision_slice_text(无映射关系)")
                continue
            # Original slice with a known revision: substitute the revised content.
            if sid in original_to_revision_id:
                revision_id = original_to_revision_id[sid]
                revision_info = revision_slice_info_map.get(revision_id)
                # The revised text is stored on the *original* slice's row.
                info = slice_info_map.get(sid)
                revision_text = info.get("revision_slice_text", "") if info else ""
                if not revision_text:
                    revision_text = r.get("content", "")
                if revision_info:
                    # Copy of the revision slice for the LLM, carrying the
                    # original slice's revision_slice_text as its content.
                    revised_chunk = {
                        "chunk_id": revision_id,
                        "content": revision_text,
                        "doc_id": revision_info.get("document_id", r.get("doc_id")),
                        "document_name": doc_name_map.get(revision_info.get("document_id"), r.get("document_name")),
                        "section": revision_info.get("section"),
                        "parent_section": revision_info.get("parent_section"),
                        "revision_status": None,
                        "ref_slice_id": None,
                        "metadata": r.get("metadata", {}),
                        "rerank_score": r.get("rerank_score", 1),
                        "is_revised_version": True,
                        "original_chunk_id": sid
                    }
                    llm_result_list.append(revised_chunk)
                    logger.info(f"使用原始切片 {sid} 的 revision_slice_text + 映射切片 {revision_id} 的元数据")
                else:
                    # Revision row missing: fall back to the inline revised text.
                    r["content"] = revision_text
                    llm_result_list.append(r)
                    logger.warning(f"未找到映射切片 {revision_id},使用原始切片 {sid} 的 revision_slice_text")
            else:
                # Unrevised slice: pass through unchanged.
                llm_result_list.append(r)
        logger.info(f"过滤后传给LLM的切片数: {len(llm_result_list)}, 废弃切片数: {len(deprecated_chunks)}")
        # Serialize the LLM-facing slices into prompt text.
        chunk_content_for_llm = self._build_llm_chunk_content(llm_result_list)
        return retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list
    async def generate_event(self, chat_json, request):
        """
        Main RAG chat streaming entry point.

        Pipeline: sensitive-word screening -> LLM-based input triage
        (OK / NEED_CLARIFY / NEED_CORRECT, with context-aware rewrite) ->
        retrieval -> revision handling -> streamed answer generation. The
        answer stream is expected to end with a sentinel line followed by
        "[cited_chunks: ...]"; the cited chunk ids are parsed and the matching
        chunk info (revised / original / deprecated) is persisted to Redis
        under the chat id for later citation lookup.

        Yields SSE-style dicts: {"id", "event": "add"|"finish"|"interrupted", "data"}.
        """
        logger.info(f"rag聊天的请求参数:{chat_json}")
        start_time = time.time()
        # The last message in the conversation is the current user query.
        query = chat_json.get("messages")[-1].get("content")
        text = query
        # Sensitive-word screening; abort immediately with a warning on any hit.
        hits = tokenizer_sensitive.detect(text)
        if hits:
            # print(build_warning(hits))
            yield {"id": "", "event": "add", "data": build_warning(hits)}
            return
        app_name = chat_json.get("name", "")
        # Triage prompt: classifies the input and produces a retrieval-friendly
        # rewrite; the model must answer with strict JSON only.
        label_prompt = """
你是一个 RAG 系统的输入判定器,只做判断,不生成答案。
请严格判断用户输入是否属于以下三类之一:
1. 语义模糊(NEED_CLARIFY):输入没有明确的主语、对象或范围,或者类似“是什么”“怎么样”“基本规范”“规定是什么”等,无法确定查询目标。
2. 明显拼写或术语错误(NEED_CORRECT):输入包含错别字、无意义组合、逻辑混乱或非标准术语,例如“天气怎么阳”“明天什么暗拍”“低功耗贾张氏设计标准”。
3. 输入清晰(OK):输入语义明确、包含实体、对象、拼写和术语正确,可直接用于检索和回答。
要求严格输出标准 JSON,只输出 JSON,不生成任何解释性文字。
输出 JSON 格式:
{{
"decision": "OK" 或 "NEED_CLARIFY" 或 "NEED_CORRECT" 或 "SENSITIVE",
"reason": "一句话说明原因(可为空字符串)",
"correct": "如果是NEED_CORRECT给我出纠正的内容,否则为空",
"suggested_question": "如果需要反问或纠正,给出建议文本,否则为空字符串,如有此回答,尽量生动形象,增强用户体验",
"summary": "如果有上下文,给出上下文摘要",
"rewrite": "结合上下文重写用户问题,去除指代和模糊表达,使其可独立理解并适合知识检索"
}}
注意事项:
- 如果输入模糊或有错别字,必须返回对应类型并提供简短建议(suggested_question)。
- 所有字段必须存在。
- 输出必须严格遵守 JSON 格式。
- 不允许有多余文字、换行或 Markdown。
- 上下文重要性:
- 如果用户问题较模糊但可依赖之前对话(如“之前问过xxx,接着问为什么?怎么做?”),判定为OK。
- 根据最近对话内容,将当前问题改写为简短、独立、明确、适合检索的查询,输出到 rewrite 字段。
- 如果有上下文,summary 字段应简明概括最近相关对话内容。
- 示例:
输入:“是什么?” → NEED_CLARIFY
输入:“都会各有各的、是什么?” → NEED_CLARIFY
输入:“天气怎么阳” → NEED_CORRECT
输入:“明天的温度是多少?” → OK
输入上下文:
Q: 明天的温度是多少?
A: 最高温度 22°C,最低温度 14°C
输入:“为什么?” → OK,rewrite: “为什么明天温度变化幅度大?”
"""
        # - 通过上下文进行判断,假如上一个问题已经回答过,用户接着问“为什么?怎么做?”之类的同样判定为OK,并根据最近对话内容,将当前用户问题改写为一个不含指代词、对象明确、可独立理解、适合用于知识检索的查询,输出到rewrite字段。
        # 用户输入:
        # "{user_input}"
        chat_json["query"] = query
        import copy
        # Deep-copy so the triage system message never leaks into chat_json.
        label_messages = copy.deepcopy(chat_json.get("messages"))
        # label_prompt = label_prompt.format(user_input=query)
        # label_messages = chat_json.get("messages").copy()
        # label_messages[-1]["content"] = label_prompt
        label_messages.insert(0, {"role": "system","content": f"{label_prompt}"})
        logger.info(f"追问识别的提示词:{label_messages}")
        completion = await self.vllm_client.generate_non_stream_async(
            prompt=label_messages,
            model="/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
        )
        raw_text = completion.strip()
        logger.info(f"问题判定的结果:{raw_text}")
        # Try to parse the triage verdict as JSON; fall back to "OK" on failure.
        try:
            data = json.loads(raw_text)
        except json.JSONDecodeError:
            logger.error("判定器返回的文本不是有效 JSON")
            data = {"decision": "OK", "reason": "", "suggested_question": ""}
        decision = data.get("decision", "OK")
        summary = data.get("summary", "")
        rewrite = data.get("rewrite", "")
        if decision == "SENSITIVE":
            # NOTE(review): the triage prompt defines no "sensitive" field, so
            # data.get("sensitive") is likely None here — confirm intent.
            yield {"id": "", "event": "add", "data": data.get("sensitive")}
            return
        if decision == "NEED_CORRECT":
            # Adopt the corrected query and continue as a normal request.
            new_query = data.get("correct", query)
            chat_json["query"] = new_query
            decision = "OK"
        if decision != "OK":
            # Vector search on the original query to ground a clarifying question.
            try:
                collection_name_list = chat_json.get("knowledgeIds", [])
                rag_embedding_name = chat_json.get("embeddingId")
                if collection_name_list and rag_embedding_name:
                    # Retrieve related context with the original query.
                    search_results, _ = self.request_milvus(
                        collection_name_list=collection_name_list,
                        rag_embedding_name=rag_embedding_name,
                        query=query,
                        top=3,  # retrieve only 3 related entries
                        mode="hybrid"
                    )
                    # Concatenate the retrieved content.
                    context_content = "\n".join([r.get("content", "") for r in search_results[:3]])
                    if context_content:
                        # Build the context-grounded clarification prompt.
                        clarify_prompt = f"""你是一个RAG应用助手。用户的输入存在问题:{data.get('reason', '输入不够清晰')}
用户输入:{query}
知识库中检索到以下相关内容:
{context_content}
请基于检索到的内容,用友好、引导性的语气向用户提出1-2个具体的问题,帮助用户明确他们的需求。要求:
1. 问题要具体、有针对性
2. 语气要友好、鼓励性
3. 可以提供一些选项供用户选择
4. 直接输出问题,不要有多余的解释
示例格式:
"您可能想了解XXX方面的内容。请问您是想了解:
1. XXX的具体定义和标准?
2. XXX的应用场景和案例?
还是有其他具体问题呢?"
"""
                        rq_tokens_str = ""
                        # Stream the clarifying question back with the chat API.
                        model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
                        full_response = ""
                        logger.info(f"反问的提示词是:{clarify_prompt}")
                        async for chunk in self.vllm_client.chat(
                            prompt=clarify_prompt,
                            model=model,
                            temperature=0.7,
                            top_p=0.9,
                            max_tokens=512,
                            stream=True,
                            history=[]
                        ):
                            if chunk.get("event") == "add":
                                # Accumulate the full response so far.
                                full_response += chunk.get("data", "")
                                rq_tokens_str = full_response
                                chunk["data"] = full_response
                                # Forward the chunk immediately to keep streaming.
                                yield {"id": chunk.get("id", ""), "event": "add", "data": chunk.get("data")}
                            elif chunk.get("event") == "finish":
                                yield {"id": chunk.get("id", ""), "event": "finish", "data": f'{{"inputToken": {self.count_tokens(query)}, "outputToken": {self.count_tokens(rq_tokens_str)}, "timeConsuming": {time.time() - start_time}}}'}
                        logger.info(f"反问生成完成,完整响应: {full_response}")
                        return
            except Exception as e:
                logger.error(f"反问生成失败,使用默认回复: {e}")
            # Retrieval failed or no knowledge base configured: fall back to the
            # triage model's suggested_question.
            reply = data.get("suggested_question", "抱歉,我没有理解您的问题。能否请您更详细地描述一下?")
            yield {"id": "", "event": "add", "data": reply}
            yield {"id": "", "event": "finish", "data": ""}
            return
        if decision == "OK":
            chat_id = ""
            # Marker the LLM is instructed to emit before its citation list.
            SENTINEL = "⧗⧗LLM_STREAM_SENTINEL⧗⧗"
            try:
                output_token_str = ""
                tokens = self.count_tokens(chat_json.get("query") + chat_json.get("prompt"))
                logger.info(f"用户输入的token数:{tokens}")
                retriever_result_list, search_doc_id_to_knowledge_id_dict = await self.retriever_result(chat_json, rewrite)
                logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
                retriever_result_list, search_doc_id_to_knowledge_id_dict, chunk_content_for_llm, chunk_revision_map, llm_result_list = self._apply_revision_and_doc_name(
                    chat_json,
                    retriever_result_list,
                    search_doc_id_to_knowledge_id_dict,
                )
                # Append the citation instruction to the prompt.
                citation_instruction = f"\n\n《额外指令》\n\n1、chunk_id仅供正文结束后使用,禁止在正文中以任何形式输出任何chunk_id\n\n2、在回答完成后,请按照以下格式列出你引用的所有chunk_id(必须输出,无引用cited_chunks可返回空字符串),格式为:{SENTINEL}\n[cited_chunks: id1, id2, id3]"
                chat_json["prompt"] = chat_json.get("prompt", "") + citation_instruction
                # Streaming state machine flags.
                first = True
                sentinel_detected = False
                buffering = False
                buffer_content = ""
                clean_output = ""
                end_time = None
                async for event in self.generate_rag_response(chat_json, chunk_content_for_llm):
                    chat_id = event.get("id")
                    if first:
                        # Time-to-first-token.
                        end_time = time.time()
                        first = False
                    if await request.is_disconnected():
                        logger.info(f"chat id:{chat_id}连接中断")
                        yield {"id": chat_id, "event": "interrupted", "data": ""}
                        return
                    if event.get("event") == "add":
                        # NOTE(review): "data" appears to carry the cumulative
                        # response so far, not a delta — the slicing below
                        # relies on that.
                        full_data = event.get("data", "")
                        output_token_str = full_data
                        # After the full sentinel: collect text to parse chunk_ids.
                        if sentinel_detected:
                            # re.IGNORECASE: case-insensitive match; [::] accepts
                            # both ASCII and full-width colons.
                            cited_match = re.search(r'\[cited_chunks[::]\s*([^\]]+)\]', full_data, re.IGNORECASE)
                            # if not cited_match:
                            #     sentinel_detected = False
                            #     continue
                            if cited_match:
                                cited_ids_str = cited_match.group(1)
                                cited_chunk_ids = [cid.strip().upper() for cid in cited_ids_str.split(',')]
                                logger.info(f"解析到LLM引用的chunk_ids: {cited_chunk_ids}")
                                # Build knowledge_info_dict for the cited chunks,
                                # storing revised and original slices separately.
                                knowledge_info_dict = {}
                                # Lower-cased chunk_id lookups to dodge casing issues.
                                llm_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in llm_result_list}
                                original_chunk_map = {r.get("chunk_id").lower() if r.get("chunk_id") else None: r for r in retriever_result_list}
                                # Debug logging.
                                logger.info(f"llm_result_list中的chunk_ids: {[r.get('chunk_id') for r in llm_result_list]}")
                                logger.info(f"retriever_result_list中的chunk_ids: {[r.get('chunk_id') for r in retriever_result_list]}")
                                for cited_id in cited_chunk_ids:
                                    cited_id_normalized = cited_id.strip().upper()
                                    # Look up in the LLM list (may be a revised copy).
                                    llm_chunk = llm_chunk_map.get(cited_id_normalized.lower())
                                    if not llm_chunk:
                                        logger.warning(f"未找到LLM引用的chunk_id: {cited_id_normalized}")
                                        sentinel_detected = False
                                        continue
                                    doc_id = llm_chunk.get("doc_id")
                                    knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id)
                                    # Is this chunk a revised version?
                                    is_revised_version = llm_chunk.get("is_revised_version", False)
                                    # Info record for the cited chunk itself.
                                    current_chunk_info = {
                                        "chunk_id": llm_chunk.get("chunk_id"),
                                        "content": llm_chunk.get("content", ""),
                                        "rerank_score": llm_chunk.get("rerank_score"),
                                        "chunk_len": len(llm_chunk.get("content", "")),
                                        "document_name": llm_chunk.get("document_name"),
                                        "revision_status": llm_chunk.get("revision_status"),
                                        "ref_slice_id": llm_chunk.get("ref_slice_id"),
                                        "section": llm_chunk.get("section"),
                                        "parent_section": llm_chunk.get("parent_section"),
                                    }
                                    chunk_info_list = []
                                    # Revised version: store both the revised and the
                                    # original slice.
                                    if is_revised_version:
                                        original_chunk_id = llm_chunk.get("original_chunk_id")
                                        # Lower-cased lookup of the original slice.
                                        original_chunk = original_chunk_map.get(original_chunk_id.lower() if original_chunk_id else None)
                                        logger.info(f"查找原始切片: original_chunk_id={original_chunk_id}, found={original_chunk is not None}")
                                        if original_chunk:
                                            # 1. Revised slice first.
                                            current_chunk_info["slice_type"] = "revised"
                                            chunk_info_list.append(current_chunk_info)
                                            # 2. Then the original slice.
                                            original_chunk_info = {
                                                "chunk_id": original_chunk.get("chunk_id"),
                                                "content": original_chunk.get("content", ""),
                                                "rerank_score": original_chunk.get("rerank_score"),
                                                "chunk_len": len(original_chunk.get("content", "")),
                                                "document_name": original_chunk.get("document_name"),
                                                "revision_status": original_chunk.get("revision_status"),
                                                "ref_slice_id": original_chunk.get("ref_slice_id"),  # DB field pointing at the revised slice
                                                "section": original_chunk.get("section"),
                                                "parent_section": original_chunk.get("parent_section"),
                                                "slice_type": "original"  # mark as the original slice
                                            }
                                            # revised_to_chunk_id / chunk_revision_map content
                                            # intentionally omitted: ref_slice_id already
                                            # points at the revised slice.
                                            chunk_info_list.append(original_chunk_info)
                                            logger.info(f"独立保存修订切片 {llm_chunk.get('chunk_id')} 和原始切片 {original_chunk_id} 信息")
                                        else:
                                            # Original not found: keep only the revised slice.
                                            current_chunk_info["slice_type"] = "revised"
                                            chunk_info_list.append(current_chunk_info)
                                            logger.warning(f"未找到修订切片 {llm_chunk.get('chunk_id')} 对应的原始切片 {original_chunk_id}")
                                    else:
                                        # Not a revised copy: check whether the cited id is a
                                        # deprecated slice (revision_status == "0").
                                        chunk_in_retriever = original_chunk_map.get(cited_id_normalized.lower())
                                        if chunk_in_retriever and chunk_in_retriever.get("revision_status") == "0":
                                            # Deprecated slice.
                                            current_chunk_info["slice_type"] = "deprecated"
                                            logger.info(f"存储废弃切片 {cited_id_normalized} 到Redis")
                                        else:
                                            # Plain slice.
                                            current_chunk_info["slice_type"] = "normal"
                                        chunk_info_list.append(current_chunk_info)
                                    doc_info = {
                                        "doc_id": doc_id,
                                        "doc_name": llm_chunk.get("document_name"),
                                        "chunk_nums": len(chunk_info_list),
                                        "chunk_info_list": chunk_info_list
                                    }
                                    # Merge into any existing doc entry for this knowledge id.
                                    if knowledge_id in knowledge_info_dict:
                                        existing_doc = None
                                        for doc in knowledge_info_dict[knowledge_id]:
                                            if doc["doc_id"] == doc_id:
                                                existing_doc = doc
                                                break
                                        if existing_doc:
                                            existing_doc["chunk_info_list"].extend(chunk_info_list)
                                            existing_doc["chunk_nums"] += len(chunk_info_list)
                                        else:
                                            knowledge_info_dict[knowledge_id].append(doc_info)
                                    else:
                                        knowledge_info_dict[knowledge_id] = [doc_info]
                                # Extra pass: also store every deprecated slice to Redis,
                                # even when the LLM did not cite it.
                                for r in retriever_result_list:
                                    if r.get("revision_status") == "0":
                                        deprecated_chunk_id = r.get("chunk_id")
                                        doc_id = r.get("doc_id")
                                        knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id)
                                        # Skip slices already stored above.
                                        already_stored = False
                                        if knowledge_id in knowledge_info_dict:
                                            for doc in knowledge_info_dict[knowledge_id]:
                                                if doc["doc_id"] == doc_id:
                                                    for chunk_info in doc["chunk_info_list"]:
                                                        if chunk_info["chunk_id"] == deprecated_chunk_id:
                                                            already_stored = True
                                                            break
                                                if already_stored:
                                                    break
                                        if not already_stored:
                                            # Build the deprecated slice record.
                                            deprecated_chunk_info = {
                                                "chunk_id": deprecated_chunk_id,
                                                "content": r.get("content", ""),
                                                "rerank_score": r.get("rerank_score"),
                                                "chunk_len": len(r.get("content", "")),
                                                "document_name": r.get("document_name"),
                                                "revision_status": "0",  # explicitly mark deprecated
                                                "ref_slice_id": r.get("ref_slice_id"),
                                                "section": r.get("section"),
                                                "parent_section": r.get("parent_section"),
                                                "slice_type": "deprecated"
                                            }
                                            doc_info = {
                                                "doc_id": doc_id,
                                                "doc_name": r.get("document_name"),
                                                "chunk_nums": 1,
                                                "chunk_info_list": [deprecated_chunk_info]
                                            }
                                            if knowledge_id in knowledge_info_dict:
                                                existing_doc = None
                                                for doc in knowledge_info_dict[knowledge_id]:
                                                    if doc["doc_id"] == doc_id:
                                                        existing_doc = doc
                                                        break
                                                if existing_doc:
                                                    existing_doc["chunk_info_list"].append(deprecated_chunk_info)
                                                    existing_doc["chunk_nums"] += 1
                                                else:
                                                    knowledge_info_dict[knowledge_id].append(doc_info)
                                            else:
                                                knowledge_info_dict[knowledge_id] = [doc_info]
                                            logger.info(f"额外存储废弃切片 {deprecated_chunk_id} 到Redis")
                                # Persist to Redis under the chat id.
                                self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                                logger.info(f"存储引用的块到Redis: {len(cited_chunk_ids)}个(包含修订、原始和废弃切片信息)")
                                # Emit the finish event.
                                output_token = self.count_tokens(clean_output)
                                time_out = end_time - start_time if end_time else 0
                                finish_data = {
                                    "inputToken": tokens,
                                    "outputToken": output_token,
                                    "timeConsuming": time_out
                                }
                                # yield {"id": chat_id, "event": "finish", "data": '你猜!'}
                                logger.info(f"111111111111111111111111111:{event}")
                                event["event"] = "finish"
                                event["data"] = finish_data
                                # yield event
                                send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'}
                                yield send
                                # yield {"id": chat_id, "event": "finish", "inputToken": tokens, "outputToken": output_token, "timeConsuming": time_out}
                                logger.info(f"{send}")
                                return  # done — ignore any later events
                            continue
                        # Check for the complete sentinel in the stream.
                        if SENTINEL in full_data:
                            sentinel_pos = full_data.find(SENTINEL)
                            sentinel_detected = True
                            clean_output = full_data[:sentinel_pos]
                            logger.info(f"检测到完整哨兵线")
                            continue
                        # Check for the sentinel's first character '⧗' (may be a
                        # partial sentinel at the end of this chunk).
                        if not buffering and '⧗' in full_data:
                            first_char_pos = full_data.find('⧗')
                            # Flush everything before '⧗'.
                            if first_char_pos > 0:
                                event["data"] = full_data[:first_char_pos]
                                yield event
                            buffering = True
                            buffer_content = full_data[first_char_pos:]
                            clean_output = full_data[:first_char_pos]
                            logger.info(f"检测到哨兵线第一个字符,开始缓冲")
                            continue
                        # While buffering, decide whether the held-back suffix is
                        # really the sentinel.
                        if buffering:
                            buffer_content = full_data[len(clean_output):]
                            if SENTINEL.startswith(buffer_content):
                                # Could still be a sentinel prefix — keep buffering.
                                continue
                            elif event.get("event") != "finish":
                                # False alarm: release the buffered content.
                                buffering = False
                                yield event
                        elif event.get("event") != "finish":
                            # Plain passthrough.
                            yield event
                    elif event.get("event") == "finish":
                        # Already finished via the sentinel path — skip this one.
                        if sentinel_detected:
                            logger.info("哨兵线已处理,跳过LLM的finish事件")
                            continue
                        # LLM ended without a sentinel: store all retrieved chunks.
                        _chunk_content, knowledge_info_dict = self.parse_retriever_list(retriever_result_list, search_doc_id_to_knowledge_id_dict)
                        self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                        logger.warning("LLM未输出哨兵线,存储所有检索块")
                        # Emit the finish event (no-sentinel case).
                        output_token = self.count_tokens(output_token_str)
                        time_out = end_time - start_time if end_time else 0
                        finish_data = {
                            "inputToken": tokens,
                            "outputToken": output_token,
                            "timeConsuming": time_out
                        }
                        event["data"] = finish_data
                        send = {"id": chat_id, "event": "finish", "data": f'{{"inputToken": {tokens}, "outputToken": {output_token}, "timeConsuming": {time_out}}}'}
                        yield send
                        logger.info(f"LLM未输出哨兵线,发送finish事件: {send}")
                    elif event.get("event") != "finish":
                        yield event
            except Exception as e:
                logger.error(f"执行出错:{e}")
                yield {"id": chat_id, "event": "finish", "data": ""}
                return
  1369. async def query_redis_chunk_by_chat_id(self, chat_id):
  1370. chunk_redis_str = self.redis_client.get(chat_id)
  1371. chunk_json = json.loads(chunk_redis_str)
  1372. logger.info(f"redis中存储的chunk 信息:{chunk_json}")
  1373. #knowledge_id = chunk_json.get("knowledge_id")
  1374. knowledge_id = list(chunk_json.keys())[0] if chunk_json else None
  1375. chunk_list = []
  1376. #for chunk in chunk_json["doc"]:
  1377. # doc_chunk_list = chunk.get("chunk_info_list")
  1378. # chunk_list.extend(doc_chunk_list)
  1379. for knowledge_id_key, doc_list in chunk_json.items():
  1380. for doc in doc_list:
  1381. doc_chunk_list = doc.get("chunk_info_list")
  1382. if doc_chunk_list:
  1383. chunk_list.extend(doc_chunk_list)
  1384. if len(chunk_list) > 3:
  1385. random_chunk_list = random.sample(chunk_list, 3)
  1386. else:
  1387. random_chunk_list = chunk_list
  1388. chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
  1389. return knowledge_id, chunk_id_list
  1390. async def generate_relevant_query(self, query_json):
  1391. chat_id = query_json.get("chat_id")
  1392. if chat_id:
  1393. knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
  1394. chunk_content_list = MilvusOperate(collection_name=knowledge_id, embedding_name=query_json.get("embedding_id"))._search_by_chunk_id_list(chunk_id_list)
  1395. chunk_content = "\n".join(chunk_content_list)
  1396. # generate_query_prompt = "请根据下面提供的信息,帮我生成"
  1397. query_messages = query_json.get("messages")
  1398. user_dict = query_messages.pop()
  1399. messages = [{"role": "assistant", "content": chunk_content}, user_dict]
  1400. else:
  1401. messages = query_json.get("messages")
  1402. #model = query_json.get("model")
  1403. model = "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507"
  1404. query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
  1405. async for result in query_result:
  1406. # result_json = json.loads(result)
  1407. # logger.info(f"生成的问题:{result_json}")
  1408. # result_str = result_json.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
  1409. result = result.get("data", "").strip()
  1410. logger.info(f"模型生成的问题:{result}")
  1411. try:
  1412. if "```json" in result:
  1413. json_pattern = r'```json\s(.*?)```'
  1414. matches = re.findall(json_pattern, result, re.DOTALL)
  1415. result = matches[0]
  1416. query_json = json.loads(result)
  1417. except Exception as e:
  1418. query_json = eval(result)
  1419. query_list = query_json.get("问题")
  1420. return {"code": 200, "data": query_list}
  1421. async def search_slice(self):
  1422. try:
  1423. chunk_redis_str = self.redis_client.get(self.chat_id)
  1424. chunk_json = json.loads(chunk_redis_str)
  1425. chunk_json["code"] = 200
  1426. except Exception as e:
  1427. logger.error(f"查询redis报错:{e}")
  1428. chunk_json = {
  1429. "code": 500,
  1430. "message": str(e)
  1431. }
  1432. return chunk_json