import ast
import gc
import json
import re

import redis

from rag.db import MilvusOperate
from rag.load_model import *
from rag.llm import VllmApi
from config import redis_config_dict
from utils.get_logger import setup_logger

logger = setup_logger(__name__)

# Both keys resolve to the same BCE reranker (tokenizer, model) pair; the
# plain "rerank" entry acts as a generic alias for the same cross-encoder.
rerank_model_mapping = {
    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model)
}

redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")


class ChatRetrieverRag:
    """RAG chat pipeline.

    Retrieves chunks from Milvus, optionally reranks them with a BCE
    cross-encoder, streams an LLM answer via vLLM, and caches retrieved-slice
    provenance in Redis keyed by the chat id.
    """

    def __init__(self, chat_json: dict = None, chat_id: str = None):
        """Create either a chat session (chat_json given) or a slice-lookup
        handle (chat_id given).

        :param chat_json: request payload used to configure the vLLM client.
        :param chat_id: existing chat id; when set, only Redis lookups
            (``search_slice``) are intended and no vLLM client is created.
        """
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # A slice lookup never talks to the model, so skip the client there.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)

    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
        """Score (query, chunk) pairs with the configured cross-encoder and
        return the hits sorted by descending rerank score.

        Each hit dict gains a ``"rerank_score"`` key (sigmoid of the logit).
        ``top_k`` is accepted for interface compatibility but not used here.

        :param query: user query string.
        :param hybrid_search_result: list of hit dicts with a "content" key.
        :param rerank_embedding_name: key into ``rerank_model_mapping``.
        :returns: the same hit dicts, mutated and sorted by rerank score.
        """
        pairs = [[query, hit["content"]] for hit in hybrid_search_result]
        tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
        rerank_model.eval()
        inputs = tokenizer(pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"重排序的得分:{scores}")
        # Sort hits by score (desc) and attach the score to each hit. Unlike
        # zip(*...), this also behaves sanely on an empty hit list.
        sorted_search_list = []
        for score, hit in sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True):
            hit["rerank_score"] = score
            sorted_search_list.append(hit)
        # Release GPU memory eagerly; rerank batches can be large.
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        return sorted_search_list

    def retriever_result(self, chat_json):
        """Search Milvus for ``chat_json["query"]`` and return the hits,
        reranked when the knowledge config enables it.

        Every returned hit carries a ``"rerank_score"`` key (a neutral 1 when
        reranking is disabled). Returns ``[]`` when nothing matched.
        """
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name = chat_json.get("knowledgeIds")
        # BUG FIX: this assignment had been commented out while `top` was
        # still referenced below, raising NameError on every call.
        top = chat_json.get("sliceCount", 5)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode = mode_search_mode.get(retriever_info.get("recall_method"), "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")
        hybrid_search_result = MilvusOperate(
            collection_name=collection_name, embedding_name=rag_embedding_name
        )._search(query, k=top, mode=mode)
        logger.info(f"根据{collection_name}检索到结果:{hybrid_search_result}")
        if not hybrid_search_result:
            return []
        if rerank_status:
            return self.rerank_result(query, top, hybrid_search_result, rerank_embedding_name)
        # No reranker configured: give every hit a neutral score so downstream
        # consumers can always rely on the key being present.
        for result in hybrid_search_result:
            result["rerank_score"] = 1
        return hybrid_search_result

    async def generate_rag_response(self, chat_json, chunk_content):
        """Stream the LLM answer with the retrieved chunk text substituted
        into the prompt template.

        The template uses the ``{知识}`` (knowledge) and ``{用户}`` (user)
        placeholders. "add" events are re-emitted with the accumulated text so
        the client always receives the full answer so far.
        """
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
        logger.info(f"请求的提示词:{prompt}")
        temperature = float(chat_json.get("temperature", 0.6))
        top_p = float(chat_json.get("topP", 0.7))
        max_token = chat_json.get("maxToken", 4096)
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        chat_resp = ""
        for chunk in self.vllm_client.chat(prompt, model, temperature=temperature,
                                           top_p=top_p, max_tokens=max_token,
                                           stream=True, history=[]):
            if chunk.get("event") == "add":
                chat_resp += chunk.get("data")
                chunk["data"] = chat_resp
            yield chunk

    def parse_retriever_list(self, retriever_result_list):
        """Aggregate retrieved chunks per source document.

        :param retriever_result_list: hit dicts with "content", "doc_id",
            "chunk_id", "rerank_score" and metadata["source"].
        :returns: ``(chunk_content, doc_list)`` — the concatenated chunk text
            and a per-document summary of chunk ids, scores and lengths.
        """
        doc_info = {}
        chunk_parts = []
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_parts.append(chunk_text)
            doc_id = retriever["doc_id"]
            chunk_meta = {
                "chunk_id": retriever["chunk_id"],
                "rerank_score": retriever["rerank_score"],
                "chunk_len": len(chunk_text),
            }
            if doc_id in doc_info:
                doc_info[doc_id]["chunk_list"].append(chunk_meta)
            else:
                doc_info[doc_id] = {
                    "doc_name": retriever["metadata"]["source"],
                    "chunk_list": [chunk_meta],
                }
        doc_list = [
            {
                "doc_id": doc_id,
                "doc_name": info["doc_name"],
                "chunk_nums": len(info["chunk_list"]),
                "chunk_info_list": info["chunk_list"],
            }
            for doc_id, info in doc_info.items()
        ]
        # "".join avoids the quadratic cost of repeated string concatenation.
        return "".join(chunk_parts), doc_list

    async def generate_event(self, chat_json, request):
        """SSE driver: retrieve, cache slice provenance under the chat id,
        and stream model events.

        Emits an "interrupted" event and stops when the client disconnects,
        and a "finish" event on any internal error.
        """
        chat_id = ""
        try:
            logger.info(f"rag聊天的请求参数:{chat_json}")
            knowledge_id = chat_json.get("knowledgeIds")
            retriever_result_list = self.retriever_result(chat_json)
            logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
            chunk_content, doc_list = self.parse_retriever_list(retriever_result_list)
            first = True
            async for event in self.generate_rag_response(chat_json, chunk_content):
                chat_id = event.get("id")
                if first:
                    # Cache slice provenance once, on the first event, so the
                    # client can later fetch it via search_slice(chat_id).
                    json_dict = {"knowledge_id": knowledge_id, "doc": doc_list}
                    self.redis_client.set(chat_id, json.dumps(json_dict))
                    first = False
                yield event
                if await request.is_disconnected():
                    logger.info(f"chat id:{chat_id}连接中断")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return
        except Exception as e:
            # logger.exception keeps the traceback (original logged at info).
            logger.exception(f"执行出错:{e}")
            yield {"id": chat_id, "event": "finish", "data": ""}
            return

    async def generate_relevant_query(self, query_json):
        """Ask the LLM for follow-up questions.

        The model may wrap its JSON in a ```json fence; strip it before
        parsing. Returns ``{"code": 200, "data": <list of questions>}``.
        """
        messages = query_json.get("messages")
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
        # Initialized so an empty result stream cannot leave it unbound.
        query_list = []
        for result in query_result:
            result = result.strip()
            logger.info(f"模型生成的问题:{result}")
            try:
                if "```json" in result:
                    matches = re.findall(r'```json\s(.*?)```', result, re.DOTALL)
                    result = matches[0]
                parsed = json.loads(result)
            except Exception:
                # SECURITY FIX: model output is untrusted; ast.literal_eval
                # only evaluates Python literals, unlike the original eval().
                parsed = ast.literal_eval(result)
            query_list = parsed.get("问题")
        return {"code": 200, "data": query_list}

    async def search_slice(self):
        """Fetch the cached slice info for ``self.chat_id`` from Redis.

        :returns: the stored JSON dict with ``code=200`` added, or a
            ``code=500`` payload on any failure (missing key, bad JSON,
            Redis error).
        """
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"查询redis报错:{e}")
            chunk_json = {
                "code": 500,
                "message": str(e)
            }
        return chunk_json