# chat_message.py
from rag.db import MilvusOperate
# The star import is expected to provide device, bce_rerank_tokenizer and
# bce_rerank_base_model, which are used below.
from rag.load_model import *
from rag.llm import VllmApi
from config import redis_config_dict
import ast
import json
import re
import gc
import redis
import torch
from utils.get_logger import setup_logger
import random
import concurrent.futures

logger = setup_logger(__name__)

# Both keys resolve to the same BCE reranker; "rerank" is kept as an alias.
rerank_model_mapping = {
    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
}

redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")
class ChatRetrieverRag:
    def __init__(self, chat_json: dict = None, chat_id: str = None):
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # A chat_id-only instance is used for slice lookups; the LLM client is
        # only needed when handling actual chat requests.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)
    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
        rerank_list = []
        for result in hybrid_search_result:
            rerank_list.append([query, result["content"]])
        tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
        # Rerank: score every (query, chunk) pair with the cross-encoder.
        rerank_model.eval()
        inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"Rerank scores: {scores}")
        sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
        sorted_scores, sorted_search = zip(*sorted_pairs)
        sorted_search_list = list(sorted_search)
        for score, search in zip(sorted_scores, sorted_search_list):
            search["rerank_score"] = score
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        # Keep the top_k highest-scoring chunks. The original random.sample here
        # discarded the ranking just computed (and raised when top_k exceeded
        # the result count); slicing fixes both.
        return sorted_search_list[:top_k]
    def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode):
        search_multi_collection_list = []
        search_doc_id_to_knowledge_id_dict = {}
        # Guard: ThreadPoolExecutor rejects max_workers=0 for an empty list.
        if not collection_name_list:
            return search_multi_collection_list, search_doc_id_to_knowledge_id_dict

        def query_collection(collection_name):
            return collection_name, MilvusOperate(
                collection_name=collection_name,
                embedding_name=rag_embedding_name
            )._search(query, k=top, mode=mode)

        # Query all collections concurrently, capping the pool at 10 workers.
        max_workers = min(10, len(collection_name_list))
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_name = {
                executor.submit(query_collection, name): name
                for name in collection_name_list
            }
            for future in concurrent.futures.as_completed(future_to_name):
                collection_name = future_to_name[future]
                try:
                    _, hybrid_search_result = future.result()
                    for result in hybrid_search_result:
                        doc_id = result.get("doc_id")
                        if doc_id not in search_doc_id_to_knowledge_id_dict:
                            search_doc_id_to_knowledge_id_dict[doc_id] = collection_name
                    search_multi_collection_list.extend(hybrid_search_result)
                except Exception as e:
                    logger.error(f"Error querying collection {collection_name}: {str(e)}")
                    continue
        return search_multi_collection_list, search_doc_id_to_knowledge_id_dict
    def retriever_result(self, chat_json):
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name_list = chat_json.get("knowledgeIds")  # a list when several knowledge bases are queried
        top = chat_json.get("sliceCount", 5)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")
        search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus(
            collection_name_list, rag_embedding_name, query, top, mode
        )
        if not search_multi_collection_list:
            rerank_result_list = []
        elif rerank_status:
            rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name)
        else:
            # Without reranking, keep the retrieval order and assign a neutral
            # score. Slicing replaces the original random.sample, which both
            # shuffled the results and raised when fewer than `top` were found.
            for result in search_multi_collection_list:
                result["rerank_score"] = 1
            rerank_result_list = search_multi_collection_list[:top]
        return rerank_result_list, search_doc_id_to_knowledge_id_dict
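
    # Illustrative chat_json for retriever_result (and the generate_* methods
    # below). The field names come from the accesses above; every value is a
    # made-up example, not a real default of the service:
    #
    # example_chat_json = {
    #     "knowledgeIds": ["kb_demo_a", "kb_demo_b"],  # Milvus collection names (hypothetical)
    #     "sliceCount": 5,                             # top-k chunks to keep
    #     "knowledgeInfo": '{"recall_method": "mixed", "rerank_status": true,'
    #                      ' "rerank_model_name": "bce_rerank_model"}',
    #     "embeddingId": "bce",                        # assumption: an id MilvusOperate accepts
    #     "query": "...",
    #     "prompt": "Context: {知识}\nQuestion: {用户}",
    # }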
    async def generate_sync_response(self, chat_json):
        """Generate a non-streaming response."""
        try:
            logger.info(f"Non-streaming RAG chat request: {chat_json}")
            # 1. Retrieve relevant documents
            retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json)
            logger.info(f"Final results from the vector store: {retriever_result_list}")
            # 2. Parse the retrieval results
            chunk_content, knowledge_info_dict = self.parse_retriever_list(
                retriever_result_list,
                search_doc_id_to_knowledge_id_dict
            )
            # 3. Build the prompt
            prompt = chat_json.get("prompt")
            query = chat_json.get("query")
            prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
            logger.info(f"Prompt sent to the model: {prompt}")
            # 4. Call the LLM
            temperature = float(chat_json.get("temperature", 0.6))
            top_p = float(chat_json.get("topP", 0.7))
            max_token = chat_json.get("maxToken", 4096)
            enable_think = chat_json.get("enable_think", False)
            model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
            # Non-streaming call: the generator yields a single completion object.
            chat_resp = ""
            chat_id = ""  # initialized so the return below works even if nothing is yielded
            async for chunk in self.vllm_client.chat(
                prompt,
                model,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_token,
                stream=False,  # non-streaming
                history=[],
                enable_think=enable_think
            ):
                chat_resp = chunk.choices[0].message.content  # full content arrives at once
                # Persist the retrieval data keyed by the completion id. The
                # chunk is an object here, not a dict, so attribute access is
                # used instead of the original chunk.get("id").
                chat_id = chunk.id
                self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
            logger.info(f"LLM answer: {chat_resp[:100]}...")
            # 5. Return the result
            return {
                "code": 200,
                "data": {
                    "answer": chat_resp,
                    "chat_id": chat_id
                }
            }
        except Exception as e:
            logger.error(f"Non-streaming RAG chat failed: {e}")
            return {
                "code": 500,
                "message": str(e)
            }
    async def generate_rag_response(self, chat_json, chunk_content):
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
        temperature = float(chat_json.get("temperature", 0.6))
        top_p = float(chat_json.get("topP", 0.7))
        max_token = chat_json.get("maxToken", 4096)
        enable_think = chat_json.get("enable_think", False)
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        # Stream the model output. "add" events carry incremental text, which
        # is accumulated so every yielded event holds the full answer so far.
        chat_resp = ""
        async for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p,
                                                 max_tokens=max_token, stream=True, history=[],
                                                 enable_think=enable_think):
            if chunk.get("event") == "add":
                chat_resp += chunk.get("data")
                chunk["data"] = chat_resp
                yield chunk
            else:
                yield chunk
    def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict):
        doc_info = {}
        chunk_content = ""
        # Group the retrieved chunks by document.
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_content += chunk_text
            doc_id = retriever["doc_id"]
            c_d = {
                "chunk_id": retriever["chunk_id"],
                "rerank_score": retriever["rerank_score"],
                "chunk_len": len(chunk_text)
            }
            if doc_id in doc_info:
                doc_info[doc_id]["chunk_list"].append(c_d)
            else:
                doc_info[doc_id] = {
                    "doc_name": retriever["metadata"]["source"],
                    "chunk_list": [c_d],
                }
        # Regroup per knowledge base:
        # {"<knowledge_id>": [{"doc_id": ..., "doc_name": ...,
        #                      "chunk_nums": ..., "chunk_info_list": [...]}, ...]}
        knowledge_info_dict = {}
        for k, v in doc_info.items():
            knowledge_id = search_doc_id_to_knowledge_id_dict.get(k)
            d = {
                "doc_id": k,
                "doc_name": v.get("doc_name"),
                "chunk_nums": len(v.get("chunk_list")),
                "chunk_info_list": v.get("chunk_list"),
            }
            if knowledge_id in knowledge_info_dict:
                knowledge_info_dict[knowledge_id].append(d)
            else:
                knowledge_info_dict[knowledge_id] = [d]
        return chunk_content, knowledge_info_dict
    async def generate_event(self, chat_json, request):
        chat_id = ""
        try:
            logger.info(f"RAG chat request: {chat_json}")
            retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json)
            logger.info(f"Final results from the vector store: {retriever_result_list}")
            chunk_content, knowledge_info_dict = self.parse_retriever_list(retriever_result_list, search_doc_id_to_knowledge_id_dict)
            first = True
            async for event in self.generate_rag_response(chat_json, chunk_content):
                chat_id = event.get("id")
                if first:
                    # Persist the retrieved slice info once, keyed by chat id.
                    self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                    first = False
                yield event
                if await request.is_disconnected():
                    logger.info(f"chat id {chat_id}: client disconnected")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return
        except Exception as e:
            logger.error(f"Execution failed: {e}")
            yield {"id": chat_id, "event": "finish", "data": ""}
            return
    async def query_redis_chunk_by_chat_id(self, chat_id):
        chunk_redis_str = self.redis_client.get(chat_id)
        chunk_json = json.loads(chunk_redis_str)
        logger.info(f"Chunk info stored in redis: {chunk_json}")
        knowledge_id = list(chunk_json.keys())[0] if chunk_json else None
        chunk_list = []
        for _knowledge_id, doc_list in chunk_json.items():
            for doc in doc_list:
                doc_chunk_list = doc.get("chunk_info_list")
                if doc_chunk_list:
                    chunk_list.extend(doc_chunk_list)
        # Sample at most 3 chunks to seed the follow-up question generation.
        if len(chunk_list) > 3:
            random_chunk_list = random.sample(chunk_list, 3)
        else:
            random_chunk_list = chunk_list
        chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
        return knowledge_id, chunk_id_list
    async def generate_relevant_query(self, query_json):
        chat_id = query_json.get("chat_id")
        if chat_id:
            knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
            chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
            chunk_content = "\n".join(chunk_content_list)
            query_messages = query_json.get("messages")
            user_dict = query_messages.pop()
            messages = [{"role": "assistant", "content": chunk_content}, user_dict]
        else:
            messages = query_json.get("messages")
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
        async for result in query_result:
            result = result.strip()
            logger.info(f"Question generated by the model: {result}")
            try:
                # Strip an optional ```json ... ``` fence before parsing;
                # \s* (instead of the original \s) tolerates any amount of
                # whitespace after the fence marker.
                if "```json" in result:
                    json_pattern = r'```json\s*(.*?)```'
                    matches = re.findall(json_pattern, result, re.DOTALL)
                    result = matches[0]
                query_json = json.loads(result)
            except Exception:
                # Fall back to Python-literal style output; ast.literal_eval
                # parses it without executing arbitrary code, unlike eval.
                query_json = ast.literal_eval(result)
            query_list = query_json.get("问题")
            return {"code": 200, "data": query_list}
    async def search_slice(self):
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"Redis lookup failed: {e}")
            chunk_json = {
                "code": 500,
                "message": str(e)
            }
        return chunk_json
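
# A minimal, hypothetical usage sketch. Assumptions: the Milvus collections,
# the Redis instance from redis_config_dict, and the vLLM endpoint behind
# VllmApi are all reachable; the field values below are illustrative only.
if __name__ == "__main__":
    import asyncio

    example_chat_json = {
        "knowledgeIds": ["kb_demo"],          # hypothetical collection name
        "sliceCount": 3,
        "knowledgeInfo": json.dumps({
            "recall_method": "mixed",
            "rerank_status": False,           # skip the reranker in this sketch
            "rerank_model_name": "bce_rerank_model",
        }),
        "embeddingId": "bce",                 # assumption: an id MilvusOperate accepts
        "query": "What does the onboarding document cover?",
        "prompt": "Context: {知识}\nQuestion: {用户}",
        "model": "DeepSeek-R1-Distill-Qwen-14B",
    }

    async def _demo():
        rag = ChatRetrieverRag(chat_json=example_chat_json)
        resp = await rag.generate_sync_response(example_chat_json)
        print(resp)

    asyncio.run(_demo())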