from rag.db import MilvusOperate from rag.load_model import * from rag.llm import VllmApi from config import redis_config_dict import json import re import gc import redis from utils.get_logger import setup_logger import random import concurrent.futures logger = setup_logger(__name__) rerank_model_mapping = { "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model), "rerank": (bce_rerank_tokenizer, bce_rerank_base_model) } redis_host = redis_config_dict.get("host") redis_port = redis_config_dict.get("port") redis_db = redis_config_dict.get("db") class ChatRetrieverRag: def __init__(self, chat_json: dict = None, chat_id: str=None): self.chat_id = chat_id self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db) if not chat_id: self.vllm_client = VllmApi(chat_json) def rerank_result(self,query, top_k, hybrid_search_result, rerank_embedding_name): rerank_list = [] for result in hybrid_search_result: rerank_list.append([query, result["content"]]) tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name) # 重排序 rerank_model.eval() inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float() logits = logits.detach().cpu() scores = torch.sigmoid(logits).tolist() # scores_list = scores.tolist() logger.info(f"重排序的得分:{scores}") sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True) sorted_scores, sorted_search = zip(*sorted_pairs) sorted_scores_list = list(sorted_scores) sorted_search_list = list(sorted_search) for score, search in zip(sorted_scores_list, sorted_search_list): search["rerank_score"] = score del inputs, logits gc.collect() torch.cuda.empty_cache() search_result = random.sample(sorted_search_list, top_k) return search_result def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode): search_multi_collection_list = [] search_doc_id_to_knowledge_id_dict = {} def query_collection(collection_name): return collection_name, MilvusOperate( collection_name=collection_name, embedding_name=rag_embedding_name )._search(query, k=top, mode=mode) max_workers = min(10, len(collection_name_list)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_name = { executor.submit(query_collection, name): name for name in collection_name_list } for future in concurrent.futures.as_completed(future_to_name): collection_name = future_to_name[future] try: _, hybrid_search_result = future.result() for result in hybrid_search_result: doc_id = result.get("doc_id") if doc_id not in search_doc_id_to_knowledge_id_dict: search_doc_id_to_knowledge_id_dict[doc_id] = collection_name search_multi_collection_list.extend(hybrid_search_result) except Exception as e: logger.error(f"查询集合 {collection_name} 时出错: {str(e)}") continue return search_multi_collection_list, search_doc_id_to_knowledge_id_dict def retriever_result(self, chat_json): mode_search_mode = { "embedding": "dense", "keyword": "sparse", "mixed": "hybrid", } collection_name_list = chat_json.get("knowledgeIds") # 如果指定多个knowledge传list top = chat_json.get("sliceCount", 5) retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}")) rag_embedding_name = chat_json.get("embeddingId") mode_rag = retriever_info.get("recall_method") mode = mode_search_mode.get(mode_rag, "hybrid") rerank_embedding_name = retriever_info.get("rerank_model_name") rerank_status = retriever_info.get("rerank_status") query = chat_json.get("query") # search_multi_collection_list = [] # search_doc_id_to_knowledge_id_dict = {} # for collection_name in collection_name_list: # hybrid_search_result = MilvusOperate( # collection_name=collection_name,embedding_name=rag_embedding_name)._search(query, k=top, mode=mode) # for result in hybrid_search_result: # doc_id = result.get("doc_id") # if doc_id not in search_doc_id_to_knowledge_id_dict: # search_doc_id_to_knowledge_id_dict[doc_id] = collection_name # search_multi_collection_list.extend(hybrid_search_result) search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus(collection_name_list, rag_embedding_name, query, top, mode) # logger.info(f"根据{collection_name_list}检索到结果:{search_multi_collection_list}") if len(search_multi_collection_list) <= 0: rerank_result_list = [] elif rerank_status: rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name) else: for result in search_multi_collection_list: result["rerank_score"] = 1 # rerank_result_list = search_multi_collection_list rerank_result_list = random.sample(search_multi_collection_list, top) return rerank_result_list, search_doc_id_to_knowledge_id_dict async def generate_sync_response(self, chat_json): """非流式生成响应""" try: logger.info(f"非流式rag聊天的请求参数:{chat_json}") # 1. 检索相关文档 retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json) logger.info(f"向量库中获取的最终结果:{retriever_result_list}") # 2. 解析检索结果 chunk_content, knowledge_info_dict = self.parse_retriever_list( retriever_result_list, search_doc_id_to_knowledge_id_dict ) # 3. 构建提示词 prompt = chat_json.get("prompt") query = chat_json.get("query") prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query) logger.info(f"请求的提示词:{prompt}") # 4. 调用LLM temperature = float(chat_json.get("temperature", 0.6)) top_p = float(chat_json.get("topP", 0.7)) max_token = chat_json.get("maxToken", 4096) enable_think = chat_json.get("enable_think", False) model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B") # 非流式调用 chat_resp = "" async for chunk in self.vllm_client.chat( prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=False, # 设置为False非流式 history=[], enable_think=enable_think ): chat_resp = chunk.choices[0].message.content # 非流式只返回一次完整内容 # 记录检索数据 chat_id = chunk.get("id") self.redis_client.set(chat_id, json.dumps(knowledge_info_dict)) logger.info(f"LLM生成的回答:{chat_resp[:100]}...") # 5. 返回结果 return { "code": 200, "data": { "answer": chat_resp, "chat_id": chat_id } } except Exception as e: logger.error(f"非流式RAG聊天出错:{e}") return { "code": 500, "message": str(e) } async def generate_rag_response(self, chat_json, chunk_content): # logger.info(f"rag聊天的请求参数:{chat_json}") # retriever_result_list = self.retriever_result(chat_json) # logger.info(f"向量库中获取的最终结果:{retriever_result_list}") prompt = chat_json.get("prompt") query = chat_json.get("query") # chunk_content = "" # for retriever in retriever_result_list: # chunk_content += retriever["content"] prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query) # logger.info(f"请求的提示词:{prompt}") temperature = float(chat_json.get("temperature", 0.6)) top_p = float(chat_json.get("topP", 0.7)) max_token = chat_json.get("maxToken", 4096) enable_think = chat_json.get("enable_think", False) # 调用模型获取返回的结果 model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B") # model = "Qwen3-Coder-30B-loft" chat_resp = "" async for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=True, history=[], enable_think=enable_think): if chunk.get("event") == "add": chat_resp += chunk.get("data") chunk["data"] = chat_resp yield chunk else: yield chunk def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict): doc_info = {} chunk_content = "" # 组织成每个doc对应的json格式 for retriever in retriever_result_list: chunk_text = retriever["content"] chunk_content += chunk_text doc_id = retriever["doc_id"] chunk_id = retriever["chunk_id"] rerank_score = retriever["rerank_score"] doc_name = retriever["metadata"]["source"] if doc_id in doc_info: c_d = { "chunk_id": chunk_id, "rerank_score": rerank_score, "chunk_len": len(chunk_text) } doc_info[doc_id]["chunk_list"].append(c_d) # doc_info[doc_id]["rerank_score"].append(rerank_score) else: c_d = { "chunk_id": chunk_id, "rerank_score": rerank_score, "chunk_len": len(chunk_text) } doc_info[doc_id] = { "doc_name": doc_name, "chunk_list": [c_d], } # doc_list = [] knowledge_info_dict = {} """ { "knowledge_id": [{}, {}], "knowledge_id" } """ for k, v in doc_info.items(): d = {} knowledge_id = search_doc_id_to_knowledge_id_dict.get(k) d["doc_id"] = k d["doc_name"] = v.get("doc_name") d["chunk_nums"] = len(v.get("chunk_list")) d["chunk_info_list"] = v.get("chunk_list") # d["chunk_len"] = len(v.get("chunk_id_list")) if knowledge_id in knowledge_info_dict: knowledge_info_dict[knowledge_id].append(d) else: knowledge_info_dict[knowledge_id] = [d] # doc_list.append(d) return chunk_content, knowledge_info_dict async def generate_event(self, chat_json, request): chat_id = "" try: logger.info(f"rag聊天的请求参数:{chat_json}") # knowledge_id = chat_json.get("knowledgeIds") retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json) logger.info(f"向量库中获取的最终结果:{retriever_result_list}") chunk_content, knowledge_info_dict = self.parse_retriever_list(retriever_result_list, search_doc_id_to_knowledge_id_dict) first = True async for event in self.generate_rag_response(chat_json, chunk_content): chat_id = event.get("id") if first: #json_dict = {"knowledge_id": knowledge_id, "doc": doc_list} self.redis_client.set(chat_id, json.dumps(knowledge_info_dict)) # logger.info(f"返回检索出的切片信息:{json_dict}") # yield {"id": chat_id, "event": "json", "data": json_dict} first = False yield event # yield json.dumps(event, ensure_ascii=False) if await request.is_disconnected(): logger.info(f"chat id:{chat_id}连接中断") yield {"id": chat_id, "event": "interrupted", "data": ""} return # yield json.dumps({"id": chat_id, "event": "interrupted", "data": ""}, ensure_ascii=False) except Exception as e: logger.info(f"执行出错:{e}") yield {"id": chat_id, "event": "finish", "data": ""} return # yield json.dumps({"id": chat_id, "event": "finish", "data": ""}, ensure_ascii=False) async def query_redis_chunk_by_chat_id(self, chat_id): chunk_redis_str = self.redis_client.get(chat_id) chunk_json = json.loads(chunk_redis_str) logger.info(f"redis中存储的chunk 信息:{chunk_json}") #knowledge_id = chunk_json.get("knowledge_id") knowledge_id = list(chunk_json.keys())[0] if chunk_json else None chunk_list = [] #for chunk in chunk_json["doc"]: # doc_chunk_list = chunk.get("chunk_info_list") # chunk_list.extend(doc_chunk_list) for knowledge_id_key, doc_list in chunk_json.items(): for doc in doc_list: doc_chunk_list = doc.get("chunk_info_list") if doc_chunk_list: chunk_list.extend(doc_chunk_list) if len(chunk_list) > 3: random_chunk_list = random.sample(chunk_list, 3) else: random_chunk_list = chunk_list chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list] return knowledge_id, chunk_id_list async def generate_relevant_query(self, query_json): chat_id = query_json.get("chat_id") if chat_id: knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id) chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list) chunk_content = "\n".join(chunk_content_list) # generate_query_prompt = "请根据下面提供的信息,帮我生成" query_messages = query_json.get("messages") user_dict = query_messages.pop() messages = [{"role": "assistant", "content": chunk_content}, user_dict] else: messages = query_json.get("messages") model = query_json.get("model") # model = "Qwen3-Coder-30B-loft" query_result = self.vllm_client.chat(model=model, stream=False, history=messages) async for result in query_result: # result_json = json.loads(result) # logger.info(f"生成的问题:{result_json}") # result_str = result_json.get("choices", [{}])[0].get("message", {}).get("content", "").strip() result = result.strip() logger.info(f"模型生成的问题:{result}") try: if "```json" in result: json_pattern = r'```json\s(.*?)```' matches = re.findall(json_pattern, result, re.DOTALL) result = matches[0] query_json = json.loads(result) except Exception as e: query_json = eval(result) query_list = query_json.get("问题") return {"code": 200, "data": query_list} async def search_slice(self): try: chunk_redis_str = self.redis_client.get(self.chat_id) chunk_json = json.loads(chunk_redis_str) chunk_json["code"] = 200 except Exception as e: logger.error(f"查询redis报错:{e}") chunk_json = { "code": 500, "message": str(e) } return chunk_json