from rag.db import MilvusOperate
# Explicit imports in place of the original wildcard; rag.load_model is
# assumed to export the BCE reranker, its tokenizer, and the torch device.
from rag.load_model import bce_rerank_tokenizer, bce_rerank_base_model, device
from rag.llm import VllmApi
from config import redis_config_dict
import ast
import gc
import json
import random
import re

import redis
import torch

from utils.get_logger import setup_logger

logger = setup_logger(__name__)

# Both keys resolve to the same BCE reranker; "rerank" is kept as an alias.
rerank_model_mapping = {
    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model),
}

redis_host = redis_config_dict.get("host")
redis_port = redis_config_dict.get("port")
redis_db = redis_config_dict.get("db")

class ChatRetrieverRag:
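    """Retrieval-augmented chat over a Milvus vector store.

    Retrieves chunks for a query, optionally reranks them with a BCE
    cross-encoder, streams the LLM answer through VllmApi, and caches the
    retrieved-slice metadata in Redis keyed by chat id.
    """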
    def __init__(self, chat_json: dict = None, chat_id: str = None):
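        """Create the Redis client; build the LLM client only for chat flows.

        chat_id is set when the instance is used solely to look up cached
        slice info (search_slice), in which case no LLM client is needed.
        """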
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        # Slice lookups by chat_id go straight to Redis and never call the LLM.
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)

    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
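        """Rerank retrieved chunks against the query with a cross-encoder.

        Each (query, chunk) pair is scored by the BCE rerank model and the
        hits are returned in descending score order with a "rerank_score"
        field attached. All hits are scored; top_k is accepted for signature
        compatibility but not used to truncate here.
        """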
        rerank_pairs = [[query, result["content"]] for result in hybrid_search_result]
        tokenizer, rerank_model = rerank_model_mapping[rerank_embedding_name]
        # Score every (query, chunk) pair with the cross-encoder.
        rerank_model.eval()
        inputs = tokenizer(rerank_pairs, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"Rerank scores: {scores}")
        sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
        sorted_scores, sorted_search = zip(*sorted_pairs)
        sorted_search_list = list(sorted_search)
        for score, search in zip(sorted_scores, sorted_search_list):
            search["rerank_score"] = score
        # Release the GPU memory held by the scoring pass.
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        return sorted_search_list

    def retriever_result(self, chat_json):
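        """Run a Milvus search for the query and optionally rerank the hits.

        The recall method from knowledgeInfo ("embedding" / "keyword" /
        "mixed") maps to the Milvus dense / sparse / hybrid search modes,
        defaulting to hybrid. Without reranking, every hit gets a neutral
        rerank_score of 1.
        """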
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name = chat_json.get("knowledgeIds")
        top = chat_json.get("sliceCount", 5)
        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
        rag_embedding_name = chat_json.get("embeddingId")
        mode_rag = retriever_info.get("recall_method")
        mode = mode_search_mode.get(mode_rag, "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")
        hybrid_search_result = MilvusOperate(
            collection_name=collection_name, embedding_name=rag_embedding_name
        )._search(query, k=top, mode=mode)
        logger.info(f"Results retrieved from {collection_name}: {hybrid_search_result}")
        if not hybrid_search_result:
            rerank_result_list = []
        elif rerank_status:
            rerank_result_list = self.rerank_result(query, top, hybrid_search_result, rerank_embedding_name)
        else:
            # No reranker configured: give every hit the same neutral score.
            for result in hybrid_search_result:
                result["rerank_score"] = 1
            rerank_result_list = hybrid_search_result
        return rerank_result_list

    async def generate_rag_response(self, chat_json, chunk_content):
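        """Fill the prompt template and stream the LLM answer.

        "{知识}" (knowledge) and "{用户}" (user) are the placeholder tokens
        the caller's prompt template is expected to contain; they are replaced
        with the retrieved chunks and the user query respectively.
        """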
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        # Substitute the template's knowledge and user placeholders.
        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
        logger.info(f"Prompt for the request: {prompt}")
        temperature = float(chat_json.get("temperature", 0.6))
        top_p = float(chat_json.get("topP", 0.7))
        max_token = chat_json.get("maxToken", 4096)
        # Call the model and relay its streamed chunks.
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        chat_resp = ""
        for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p,
                                           max_tokens=max_token, stream=True, history=[]):
            if chunk.get("event") == "add":
                # Accumulate so every "add" event carries the full text so far.
                chat_resp += chunk.get("data")
                chunk["data"] = chat_resp
                yield chunk
            else:
                yield chunk

    def parse_retriever_list(self, retriever_result_list):
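        """Concatenate chunk texts and group chunk metadata by document.

        Returns the joined chunk content (fed into the prompt) and a list of
        per-document dicts carrying doc_id, doc_name, chunk_nums, and per-chunk
        info (chunk_id, rerank_score, chunk_len).
        """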
        doc_info = {}
        chunk_content = ""
        # Group the retrieved chunks by their source document.
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_content += chunk_text
            doc_id = retriever["doc_id"]
            doc_name = retriever["metadata"]["source"]
            c_d = {
                "chunk_id": retriever["chunk_id"],
                "rerank_score": retriever["rerank_score"],
                "chunk_len": len(chunk_text),
            }
            if doc_id in doc_info:
                doc_info[doc_id]["chunk_list"].append(c_d)
            else:
                doc_info[doc_id] = {
                    "doc_name": doc_name,
                    "chunk_list": [c_d],
                }
        doc_list = []
        for doc_id, info in doc_info.items():
            doc_list.append({
                "doc_id": doc_id,
                "doc_name": info.get("doc_name"),
                "chunk_nums": len(info.get("chunk_list")),
                "chunk_info_list": info.get("chunk_list"),
            })
        return chunk_content, doc_list

    async def generate_event(self, chat_json, request):
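        """Full RAG pipeline for one chat turn, yielded as a stream of events.

        Retrieves and parses chunks, caches the slice info in Redis under the
        chat id on the first streamed event, relays the LLM events, and emits
        an "interrupted" event if the client disconnects. Errors are logged
        and closed out with a synthesized "finish" event.
        """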
- chat_id = ""
- try:
- logger.info(f"rag聊天的请求参数:{chat_json}")
- knowledge_id = chat_json.get("knowledgeIds")
- retriever_result_list = self.retriever_result(chat_json)
- logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
- chunk_content, doc_list = self.parse_retriever_list(retriever_result_list)
- first = True
- async for event in self.generate_rag_response(chat_json, chunk_content):
- chat_id = event.get("id")
- if first:
- json_dict = {"knowledge_id": knowledge_id, "doc": doc_list}
- self.redis_client.set(chat_id, json.dumps(json_dict))
- # logger.info(f"返回检索出的切片信息:{json_dict}")
- # yield {"id": chat_id, "event": "json", "data": json_dict}
- first = False
- yield event
- # yield json.dumps(event, ensure_ascii=False)
- if await request.is_disconnected():
- logger.info(f"chat id:{chat_id}连接中断")
- yield {"id": chat_id, "event": "interrupted", "data": ""}
- return
- # yield json.dumps({"id": chat_id, "event": "interrupted", "data": ""}, ensure_ascii=False)
- except Exception as e:
- logger.info(f"执行出错:{e}")
- yield {"id": chat_id, "event": "finish", "data": ""}
- return
- # yield json.dumps({"id": chat_id, "event": "finish", "data": ""}, ensure_ascii=False)
    async def query_redis_chunk_by_chat_id(self, chat_id):
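        """Load cached slice info for a chat id and sample up to 3 chunk ids.

        Returns the knowledge (collection) id plus at most three randomly
        sampled chunk ids, used as context for follow-up question generation.
        """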
        chunk_redis_str = self.redis_client.get(chat_id)
        chunk_json = json.loads(chunk_redis_str)
        logger.info(f"Chunk info stored in redis: {chunk_json}")
        knowledge_id = chunk_json.get("knowledge_id")
        chunk_list = []
        for chunk in chunk_json["doc"]:
            chunk_list.extend(chunk.get("chunk_info_list"))
        # Sample at most three chunks to keep the follow-up prompt short.
        if len(chunk_list) > 3:
            random_chunk_list = random.sample(chunk_list, 3)
        else:
            random_chunk_list = chunk_list
        chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
        return knowledge_id, chunk_id_list

    async def generate_relevant_query(self, query_json):
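        """Generate follow-up questions, optionally grounded in cached chunks.

        With a chat_id, previously retrieved chunks are fetched from Milvus
        and injected as assistant context ahead of the latest user message.
        The model is expected to reply with JSON whose "问题" ("questions")
        key holds the question list, optionally wrapped in a ```json fence.
        """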
        chat_id = query_json.get("chat_id")
        if chat_id:
            knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
            chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
            chunk_content = "\n".join(chunk_content_list)
            query_messages = query_json.get("messages")
            # Inject the retrieved chunks as assistant context ahead of the
            # latest user message.
            user_dict = query_messages.pop()
            messages = [{"role": "assistant", "content": chunk_content}, user_dict]
        else:
            messages = query_json.get("messages")
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)

        for result in query_result:
            result = result.strip()
            logger.info(f"Questions generated by the model: {result}")
            try:
                # The model may wrap its JSON in a ```json fence.
                if "```json" in result:
                    json_pattern = r'```json\s(.*?)```'
                    matches = re.findall(json_pattern, result, re.DOTALL)
                    result = matches[0]
                query_json = json.loads(result)
            except Exception:
                # Fall back to parsing a Python-literal dict; literal_eval is
                # used instead of eval so arbitrary code cannot execute.
                query_json = ast.literal_eval(result)
            # "问题" ("questions") is the key the prompt asks the model to use.
            query_list = query_json.get("问题")
            return {"code": 200, "data": query_list}

    async def search_slice(self):
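        """Return the slice info cached in Redis for this instance's chat id."""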
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"Redis lookup failed: {e}")
            chunk_json = {
                "code": 500,
                "message": str(e),
            }
        return chunk_json
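
# Minimal usage sketch (an assumption, not part of this module: it presumes a
# FastAPI endpoint whose Request supports is_disconnected(), as used above,
# and sse-starlette's EventSourceResponse for streaming; names illustrative):
#
#   @app.post("/rag/chat")
#   async def rag_chat(request: Request):
#       chat_json = await request.json()
#       retriever = ChatRetrieverRag(chat_json=chat_json)
#       return EventSourceResponse(retriever.generate_event(chat_json, request))
#
#   # Later, fetch the cached slice info for a finished chat:
#   slice_info = await ChatRetrieverRag(chat_id=chat_id).search_slice()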