# chat_message.py

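"""Retrieval-augmented chat helpers.

Combines Milvus retrieval, optional BCE cross-encoder reranking, streaming
generation through a vLLM endpoint, and Redis caching of the retrieved
slices keyed by chat id.
"""
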
from rag.db import MilvusOperate
from rag.load_model import *  # expected to provide torch, device, bce_rerank_tokenizer, bce_rerank_base_model
from rag.llm import VllmApi
from config import redis_config_dict
import json
import re
import gc
import ast
import redis
from utils.get_logger import setup_logger
import random

logger = setup_logger(__name__)

# Both names resolve to the same BCE reranker pair loaded in rag.load_model.
rerank_model_mapping = {
    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
}

redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")


class ChatRetrieverRag:
    def __init__(self, chat_json: dict = None, chat_id: str = None):
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # The vLLM client is only needed for chat-style calls; lookups by an
        # existing chat_id go straight to Redis.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)

    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
        # Note: top_k is currently unused; all hits are scored and returned.
        rerank_list = []
        for result in hybrid_search_result:
            rerank_list.append([query, result["content"]])
        tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
        # Rerank: score every (query, chunk) pair with the cross-encoder and
        # turn the logits into (0, 1) relevance scores via sigmoid.
        rerank_model.eval()
        inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"Rerank scores: {scores}")
        # Sort the hits by score, highest first, and attach the score to each hit.
        sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
        sorted_scores, sorted_search = zip(*sorted_pairs)
        sorted_search_list = list(sorted_search)
        for score, search in zip(sorted_scores, sorted_search_list):
            search["rerank_score"] = score
        # Free the GPU memory held by the rerank forward pass.
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        return sorted_search_list

    def retriever_result(self, chat_json):
        # Map the request's recall_method onto Milvus search modes.
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name = chat_json.get("knowledgeIds")
        top = chat_json.get("sliceCount", 5)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")
        hybrid_search_result = MilvusOperate(
            collection_name=collection_name, embedding_name=rag_embedding_name
        )._search(query, k=top, mode=mode)
        logger.info(f"Results retrieved from {collection_name}: {hybrid_search_result}")
        if not hybrid_search_result:
            rerank_result_list = []
        elif rerank_status:
            rerank_result_list = self.rerank_result(query, top, hybrid_search_result, rerank_embedding_name)
        else:
            # Reranking disabled: keep the retrieval order and assign a neutral
            # score so downstream consumers always see rerank_score.
            for result in hybrid_search_result:
                result["rerank_score"] = 1
            rerank_result_list = hybrid_search_result
        return rerank_result_list

    async def generate_rag_response(self, chat_json, chunk_content):
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        # The prompt template uses the placeholders {知识} (knowledge) and
        # {用户} (user); fill them with the retrieved chunks and the query.
        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
        logger.info(f"Prompt for the request: {prompt}")
        temperature = float(chat_json.get("temperature", 0.6))
        top_p = float(chat_json.get("topP", 0.7))
        max_token = chat_json.get("maxToken", 4096)
        # Call the model and stream back its output.
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        chat_resp = ""
        for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p,
                                           max_tokens=max_token, stream=True, history=[]):
            if chunk.get("event") == "add":
                # "add" chunks carry a delta; forward the accumulated text instead.
                chat_resp += chunk.get("data")
                chunk["data"] = chat_resp
                yield chunk
            else:
                yield chunk
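
    # Chunks produced above are dicts shaped like
    # {"id": <chat id>, "event": "add" | ..., "data": <text>}; generate_event
    # below layers its own "interrupted" and "finish" events on top.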

    def parse_retriever_list(self, retriever_result_list):
        doc_info = {}
        chunk_content = ""
        # Group the retrieved chunks into one record per source document.
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_content += chunk_text
            doc_id = retriever["doc_id"]
            chunk_id = retriever["chunk_id"]
            rerank_score = retriever["rerank_score"]
            doc_name = retriever["metadata"]["source"]
            c_d = {
                "chunk_id": chunk_id,
                "rerank_score": rerank_score,
                "chunk_len": len(chunk_text),
            }
            if doc_id in doc_info:
                doc_info[doc_id]["chunk_list"].append(c_d)
            else:
                doc_info[doc_id] = {
                    "doc_name": doc_name,
                    "chunk_list": [c_d],
                }
        doc_list = []
        for k, v in doc_info.items():
            doc_list.append({
                "doc_id": k,
                "doc_name": v.get("doc_name"),
                "chunk_nums": len(v.get("chunk_list")),
                "chunk_info_list": v.get("chunk_list"),
            })
        return chunk_content, doc_list
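
    # Illustrative shape of the return values (example values, not real data):
    #   chunk_content -> "first chunk textsecond chunk text..."
    #   doc_list -> [{"doc_id": "d1", "doc_name": "manual.pdf", "chunk_nums": 2,
    #                 "chunk_info_list": [{"chunk_id": "c1", "rerank_score": 0.93,
    #                                      "chunk_len": 118}, ...]}]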

    async def generate_event(self, chat_json, request):
        chat_id = ""
        try:
            logger.info(f"RAG chat request parameters: {chat_json}")
            knowledge_id = chat_json.get("knowledgeIds")
            retriever_result_list = self.retriever_result(chat_json)
            logger.info(f"Final results fetched from the vector store: {retriever_result_list}")
            chunk_content, doc_list = self.parse_retriever_list(retriever_result_list)
            first = True
            async for event in self.generate_rag_response(chat_json, chunk_content):
                chat_id = event.get("id")
                if first:
                    # On the first event, cache the retrieved slice info under
                    # the chat id so search_slice can return it later.
                    json_dict = {"knowledge_id": knowledge_id, "doc": doc_list}
                    self.redis_client.set(chat_id, json.dumps(json_dict))
                    first = False
                yield event
                if await request.is_disconnected():
                    logger.info(f"chat id: {chat_id} connection interrupted")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return
        except Exception as e:
            logger.error(f"Execution failed: {e}")
            yield {"id": chat_id, "event": "finish", "data": ""}
            return

    async def query_redis_chunk_by_chat_id(self, chat_id):
        chunk_redis_str = self.redis_client.get(chat_id)
        chunk_json = json.loads(chunk_redis_str)
        logger.info(f"Chunk info stored in Redis: {chunk_json}")
        knowledge_id = chunk_json.get("knowledge_id")
        chunk_list = []
        for chunk in chunk_json["doc"]:
            doc_chunk_list = chunk.get("chunk_info_list")
            chunk_list.extend(doc_chunk_list)
        # Sample at most three chunks to keep the follow-up prompt short.
        if len(chunk_list) > 3:
            random_chunk_list = random.sample(chunk_list, 3)
        else:
            random_chunk_list = chunk_list
        chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
        return knowledge_id, chunk_id_list

    async def generate_relevant_query(self, query_json):
        chat_id = query_json.get("chat_id")
        if chat_id:
            # Ground the suggested questions in chunks from the previous conversation.
            knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
            chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
            chunk_content = "\n".join(chunk_content_list)
            query_messages = query_json.get("messages")
            user_dict = query_messages.pop()
            messages = [{"role": "assistant", "content": chunk_content}, user_dict]
        else:
            messages = query_json.get("messages")
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
        for result in query_result:
            result = result.strip()
            logger.info(f"Questions generated by the model: {result}")
            try:
                # Strip an optional ```json ... ``` fence before parsing.
                if "```json" in result:
                    json_pattern = r'```json\s*(.*?)```'
                    matches = re.findall(json_pattern, result, re.DOTALL)
                    result = matches[0]
                query_json = json.loads(result)
            except Exception:
                # Fall back to Python-literal parsing (e.g. single-quoted dicts);
                # ast.literal_eval is safe where eval would not be.
                query_json = ast.literal_eval(result)
            query_list = query_json.get("问题")
            return {"code": 200, "data": query_list}
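
    # Judging from the parsing above, the model reply is expected to look like:
    #   ```json
    #   {"问题": ["question 1", "question 2", "question 3"]}
    #   ```
    # where the "问题" ("questions") key is presumably fixed by the prompt.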

    async def search_slice(self):
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"Redis lookup failed: {e}")
            chunk_json = {
                "code": 500,
                "message": str(e),
            }
        return chunk_json
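

# A minimal usage sketch (hypothetical payload values; the keys mirror what the
# methods above read). Assumes the Redis, Milvus, and vLLM endpoints from
# config are reachable, and that "demo-chat-id" names an existing conversation.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Slice lookup for an existing conversation.
        rag = ChatRetrieverRag(chat_id="demo-chat-id")
        print(await rag.search_slice())

    asyncio.run(_demo())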