chat_message.py

from rag.db import MilvusOperate
from rag.load_model import *  # provides torch, device and the bce_rerank_* objects
from rag.llm import VllmApi
from config import redis_config_dict
import json
import re
import ast
import gc
import redis
from utils.get_logger import setup_logger

logger = setup_logger(__name__)

# Both names resolve to the same BCE cross-encoder reranker.
rerank_model_mapping = {
    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
}

redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")


class ChatRetrieverRag:
    def __init__(self, chat_json: dict = None, chat_id: str = None):
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # Only chat requests need an LLM client; slice lookups pass chat_id alone.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)
    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
        """Re-score retrieval hits with the cross-encoder and sort by descending score."""
        rerank_list = [[query, result["content"]] for result in hybrid_search_result]
        tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
        # Rerank: batch-score every (query, passage) pair.
        rerank_model.eval()
        inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"Rerank scores: {scores}")
        sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
        sorted_scores, sorted_search = zip(*sorted_pairs)
        sorted_search_list = list(sorted_search)
        for score, search in zip(sorted_scores, sorted_search_list):
            search["rerank_score"] = score
        # Release the GPU memory held by the tokenized batch.
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        return sorted_search_list
    def retriever_result(self, chat_json):
        # Map the API's recall_method names onto Milvus search modes.
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name = chat_json.get("knowledgeIds")
        top = chat_json.get("sliceCount", 5)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")
        hybrid_search_result = MilvusOperate(
            collection_name=collection_name, embedding_name=rag_embedding_name
        )._search(query, k=top, mode=mode)
        logger.info(f"Results retrieved from {collection_name}: {hybrid_search_result}")
        if len(hybrid_search_result) <= 0:
            rerank_result_list = []
        elif rerank_status:
            rerank_result_list = self.rerank_result(query, top, hybrid_search_result, rerank_embedding_name)
        else:
            # No reranker: keep retrieval order and give every hit a neutral score.
            for result in hybrid_search_result:
                result["rerank_score"] = 1
            rerank_result_list = hybrid_search_result
        return rerank_result_list
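
    # Illustrative shape of one retrieval hit, inferred from the fields the
    # methods in this class access; the real Milvus payload may carry more keys:
    # {
    #     "content": "chunk text ...",
    #     "doc_id": "doc-123",
    #     "chunk_id": "chunk-7",
    #     "metadata": {"source": "manual.pdf"},
    #     "rerank_score": 0.93,  # attached by rerank_result / retriever_result
    # }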
    async def generate_rag_response(self, chat_json, chunk_content):
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        # The prompt template uses the placeholders {知识} (knowledge) and {用户} (user query).
        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
        logger.info(f"Final prompt: {prompt}")
        temperature = float(chat_json.get("temperature", 0.6))
        top_p = float(chat_json.get("topP", 0.7))
        max_token = chat_json.get("maxToken", 4096)
        # Stream the model output; "add" events are rewritten to carry the full
        # answer accumulated so far rather than just the latest delta.
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        chat_resp = ""
        for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p,
                                           max_tokens=max_token, stream=True, history=[]):
            if chunk.get("event") == "add":
                chat_resp += chunk.get("data")
                chunk["data"] = chat_resp
                yield chunk
            else:
                yield chunk
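
    # Assumed event contract, inferred from the handling above: VllmApi.chat
    # yields dicts like {"id": <chat id>, "event": "add", "data": <text delta>};
    # non-"add" events (e.g. a closing "finish") pass through unchanged.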
    def parse_retriever_list(self, retriever_result_list):
        doc_info = {}
        chunk_content = ""
        # Group the retrieved chunks into one JSON record per document.
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_content += chunk_text
            doc_id = retriever["doc_id"]
            c_d = {
                "chunk_id": retriever["chunk_id"],
                "rerank_score": retriever["rerank_score"],
                "chunk_len": len(chunk_text),
            }
            if doc_id in doc_info:
                doc_info[doc_id]["chunk_list"].append(c_d)
            else:
                doc_info[doc_id] = {
                    "doc_name": retriever["metadata"]["source"],
                    "chunk_list": [c_d],
                }
        doc_list = []
        for doc_id, info in doc_info.items():
            doc_list.append({
                "doc_id": doc_id,
                "doc_name": info.get("doc_name"),
                "chunk_nums": len(info.get("chunk_list")),
                "chunk_info_list": info.get("chunk_list"),
            })
        return chunk_content, doc_list
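
    # Example doc_list entry built above (illustrative values only):
    # {
    #     "doc_id": "doc-123",
    #     "doc_name": "manual.pdf",
    #     "chunk_nums": 2,
    #     "chunk_info_list": [
    #         {"chunk_id": "chunk-7", "rerank_score": 0.93, "chunk_len": 412},
    #         {"chunk_id": "chunk-9", "rerank_score": 0.88, "chunk_len": 380},
    #     ],
    # }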
    async def generate_event(self, chat_json, request):
        chat_id = ""
        try:
            logger.info(f"RAG chat request params: {chat_json}")
            knowledge_id = chat_json.get("knowledgeIds")
            retriever_result_list = self.retriever_result(chat_json)
            logger.info(f"Final results from the vector store: {retriever_result_list}")
            chunk_content, doc_list = self.parse_retriever_list(retriever_result_list)
            first = True
            async for event in self.generate_rag_response(chat_json, chunk_content):
                chat_id = event.get("id")
                if first:
                    # Cache the retrieved slice info under the chat id so that
                    # search_slice can return it later.
                    json_dict = {"knowledge_id": knowledge_id, "doc": doc_list}
                    self.redis_client.set(chat_id, json.dumps(json_dict))
                    first = False
                yield event
                if await request.is_disconnected():
                    logger.info(f"chat id {chat_id}: client disconnected")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return
        except Exception as e:
            logger.error(f"Execution failed: {e}")
            yield {"id": chat_id, "event": "finish", "data": ""}
            return
    async def generate_relevant_query(self, query_json):
        messages = query_json.get("messages")
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
        query_list = []
        for result in query_result:
            result = result.strip()
            logger.info(f"Model-generated questions: {result}")
            try:
                # The model may wrap its JSON answer in a ```json fence.
                if "```json" in result:
                    json_pattern = r'```json\s(.*?)```'
                    matches = re.findall(json_pattern, result, re.DOTALL)
                    result = matches[0]
                query_json = json.loads(result)
            except Exception:
                # Fall back to a Python dict literal (safer than the original eval).
                query_json = ast.literal_eval(result)
            query_list = query_json.get("问题")  # "问题" ("questions") is the key in the model's JSON
        return {"code": 200, "data": query_list}
    async def search_slice(self):
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"Redis lookup failed: {e}")
            chunk_json = {
                "code": 500,
                "message": str(e),
            }
        return chunk_json
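

# Minimal usage sketch (an assumption, not part of the original module): fetch
# the slice info that generate_event cached in Redis for a finished chat.
# The chat id below is hypothetical.
if __name__ == "__main__":
    import asyncio

    retriever = ChatRetrieverRag(chat_id="demo-chat-id")  # hypothetical id
    print(asyncio.run(retriever.search_slice()))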