| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- from rag.db import MilvusOperate
- from rag.load_model import *
- from rag.llm import VllmApi
- from config import redis_config_dict
- import json
- import re
- import gc
- import redis
- from utils.get_logger import setup_logger
- import random
- import concurrent.futures
# Module-level logger.
logger = setup_logger(__name__)

# Reranker registry: name -> (tokenizer, model). Both aliases point at the same
# BCE reranker pair (names come in through the star import of rag.load_model).
rerank_model_mapping = dict(
    bce_rerank_model=(bce_rerank_tokenizer, bce_rerank_base_model),
    rerank=(bce_rerank_tokenizer, bce_rerank_base_model),
)

# Redis connection settings; a missing key yields None.
redis_host, redis_port, redis_db = (redis_config_dict.get(key) for key in ("host", "port", "db"))
class ChatRetrieverRag:
    """Retrieval-augmented chat over Milvus knowledge bases.

    Retrieves chunks from one or more Milvus collections, optionally reranks
    them, fills them into the request prompt and queries the LLM through
    ``VllmApi``. The retrieval info of each answer (knowledge id -> documents
    -> chunks) is stored in Redis under the answer's chat id, so it can be read
    back by ``search_slice`` and ``generate_relevant_query``.
    """

    def __init__(self, chat_json: dict = None, chat_id: str = None):
        """Set up the Redis client and, when ``chat_id`` is not given, the LLM client.

        Args:
            chat_json: request payload handed to ``VllmApi``.
            chat_id: id of an existing chat. When set, no ``vllm_client`` is
                created, so methods that call the LLM cannot be used on this
                instance.
        """
        self.chat_id = chat_id
        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
        if not chat_id:
            self.vllm_client = VllmApi(chat_json)

    # ------------------------------------------------------------------ #
    # Retrieval
    # ------------------------------------------------------------------ #

    @staticmethod
    def _select_top_k(scores, results, top_k):
        """Attach ``rerank_score`` to every result and return the ``top_k`` best.

        Results are ordered by descending score (ties keep their input order).
        A ``top_k`` larger than the number of results returns all of them.
        """
        ranked = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
        for score, result in ranked:
            result["rerank_score"] = score
        return [result for _, result in ranked[:max(top_k, 0)]]

    def rerank_result(self, query, top_k, hybrid_search_result, rerank_embedding_name):
        """Rerank retrieved chunks with a cross-encoder and keep the best ones.

        Args:
            query: user query text.
            top_k: maximum number of chunks to return.
            hybrid_search_result: list of result dicts, each with a ``"content"`` key.
            rerank_embedding_name: key into ``rerank_model_mapping``.

        Returns:
            Up to ``top_k`` result dicts in descending order of rerank score;
            every input dict gets a ``"rerank_score"`` (sigmoid of the logit).

        Raises:
            ValueError: if ``rerank_embedding_name`` is not a registered reranker.
        """
        model_pair = rerank_model_mapping.get(rerank_embedding_name)
        if model_pair is None:
            raise ValueError(f"unknown rerank model: {rerank_embedding_name}")
        tokenizer, rerank_model = model_pair

        rerank_list = [[query, result["content"]] for result in hybrid_search_result]
        rerank_model.eval()
        inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = rerank_model(**inputs, return_dict=True).logits.view(-1).float()
        logits = logits.detach().cpu()
        scores = torch.sigmoid(logits).tolist()
        logger.info(f"重排序的得分:{scores}")
        # Free the batch tensors and cached GPU memory after every rerank call.
        del inputs, logits
        gc.collect()
        torch.cuda.empty_cache()
        # Keep the highest-scoring chunks; sampling at random here would discard the rerank.
        return self._select_top_k(scores, hybrid_search_result, top_k)

    def request_milvus(self, collection_name_list, rag_embedding_name, query, top, mode):
        """Search several Milvus collections concurrently.

        Args:
            collection_name_list: collection (knowledge base) names; may be empty or None.
            rag_embedding_name: embedding model name passed to ``MilvusOperate``.
            query: query text.
            top: ``k`` for each per-collection search.
            mode: search mode (``"dense"``, ``"sparse"`` or ``"hybrid"``).

        Returns:
            ``(results, doc_id_to_collection)``: the hits of all collections,
            concatenated in ``collection_name_list`` order, and a map from each
            hit's ``doc_id`` to the first collection (in list order) returning
            it. A collection whose search raises is logged and skipped.
        """
        search_multi_collection_list = []
        search_doc_id_to_knowledge_id_dict = {}
        if not collection_name_list:
            # ThreadPoolExecutor rejects max_workers=0; nothing to search anyway.
            return search_multi_collection_list, search_doc_id_to_knowledge_id_dict

        def query_collection(collection_name):
            return MilvusOperate(
                collection_name=collection_name,
                embedding_name=rag_embedding_name,
            )._search(query, k=top, mode=mode)

        max_workers = min(10, len(collection_name_list))
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [(name, executor.submit(query_collection, name)) for name in collection_name_list]
            # Collect in submission order, not completion order, so the merged
            # list and the doc_id -> collection mapping are deterministic.
            for collection_name, future in futures:
                try:
                    hybrid_search_result = future.result()
                    for result in hybrid_search_result:
                        search_doc_id_to_knowledge_id_dict.setdefault(result.get("doc_id"), collection_name)
                    search_multi_collection_list.extend(hybrid_search_result)
                except Exception as e:
                    logger.error(f"查询集合 {collection_name} 时出错: {str(e)}")
        return search_multi_collection_list, search_doc_id_to_knowledge_id_dict

    def retriever_result(self, chat_json):
        """Retrieve (and optionally rerank) chunks for a chat request.

        Reads from ``chat_json``: ``knowledgeIds`` (collection names),
        ``sliceCount`` (number of chunks, default 5), ``embeddingId``, ``query``
        and ``knowledgeInfo`` (JSON string or dict with ``recall_method``,
        ``rerank_model_name`` and ``rerank_status``).

        Returns:
            ``(chunks, doc_id_to_collection)``. At most ``sliceCount`` chunks,
            each carrying a ``"rerank_score"`` (1 when reranking is disabled).
        """
        mode_search_mode = {
            "embedding": "dense",
            "keyword": "sparse",
            "mixed": "hybrid",
        }
        collection_name_list = chat_json.get("knowledgeIds")  # list of knowledge ids when several are targeted
        top = chat_json.get("sliceCount", 5)
        knowledge_info = chat_json.get("knowledgeInfo") or "{}"
        retriever_info = knowledge_info if isinstance(knowledge_info, dict) else json.loads(knowledge_info)
        rag_embedding_name = chat_json.get("embeddingId")
        mode = mode_search_mode.get(retriever_info.get("recall_method"), "hybrid")
        rerank_embedding_name = retriever_info.get("rerank_model_name")
        rerank_status = retriever_info.get("rerank_status")
        query = chat_json.get("query")

        search_multi_collection_list, search_doc_id_to_knowledge_id_dict = self.request_milvus(
            collection_name_list, rag_embedding_name, query, top, mode
        )

        if not search_multi_collection_list:
            rerank_result_list = []
        elif rerank_status:
            rerank_result_list = self.rerank_result(query, top, search_multi_collection_list, rerank_embedding_name)
        else:
            for result in search_multi_collection_list:
                result["rerank_score"] = 1
            # No common score across collections: pick a random subset, capped
            # at the number of hits (random.sample raises if k > population).
            sample_size = min(top, len(search_multi_collection_list))
            rerank_result_list = random.sample(search_multi_collection_list, sample_size)
        return rerank_result_list, search_doc_id_to_knowledge_id_dict

    # ------------------------------------------------------------------ #
    # Prompt / LLM helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _build_prompt(chat_json, chunk_content):
        """Fill the ``{知识}`` (knowledge) and ``{用户}`` (user query) placeholders of the request prompt."""
        prompt = chat_json.get("prompt")
        query = chat_json.get("query")
        return prompt.replace("{知识}", chunk_content).replace("{用户}", query)

    @staticmethod
    def _sampling_params(chat_json):
        """Collect the LLM sampling keyword arguments from the request payload."""
        return {
            "temperature": float(chat_json.get("temperature", 0.6)),
            "top_p": float(chat_json.get("topP", 0.7)),
            "max_tokens": chat_json.get("maxToken", 4096),
            "enable_think": chat_json.get("enable_think", False),
        }

    @staticmethod
    def _unpack_sync_chunk(chunk):
        """Return ``(content, chat_id)`` from one non-streaming LLM result.

        NOTE(review): ``generate_relevant_query`` consumes ``stream=False``
        output as a plain string, while this path previously read it as an
        OpenAI-style response object. Plain strings, dicts and response objects
        are all accepted until the ``VllmApi`` contract is confirmed.
        ``chat_id`` is None when the result carries no id.
        """
        if isinstance(chunk, str):
            return chunk, None
        if isinstance(chunk, dict):
            choices = chunk.get("choices") or [{}]
            message = choices[0].get("message") or {}
            content = message.get("content") or chunk.get("data") or ""
            return content, chunk.get("id")
        choices = getattr(chunk, "choices", None) or []
        content = choices[0].message.content if choices else ""
        return content or "", getattr(chunk, "id", None)

    @staticmethod
    def _parse_llm_json(text):
        """Parse model output as JSON, unwrapping a ```json fenced block if present.

        Falls back to ``ast.literal_eval`` for Python-literal output (e.g.
        single-quoted dicts). ``literal_eval`` is used instead of ``eval``
        because the text is untrusted model output.

        Raises:
            ValueError / SyntaxError: if the text is neither JSON nor a Python literal.
        """
        import ast

        if "```json" in text:
            matches = re.findall(r'```json\s(.*?)```', text, re.DOTALL)
            if matches:
                text = matches[0]
        try:
            return json.loads(text)
        except ValueError:
            return ast.literal_eval(text.strip())

    # ------------------------------------------------------------------ #
    # Chat endpoints
    # ------------------------------------------------------------------ #

    async def generate_sync_response(self, chat_json):
        """Answer a RAG request with one non-streaming LLM call.

        Returns:
            ``{"code": 200, "data": {"answer": ..., "chat_id": ...}}`` on
            success, ``{"code": 500, "message": ...}`` if any step fails.
        """
        try:
            logger.info(f"非流式rag聊天的请求参数:{chat_json}")

            # 1. Retrieve relevant chunks.
            retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json)
            logger.info(f"向量库中获取的最终结果:{retriever_result_list}")

            # 2. Turn them into prompt context plus per-knowledge-base metadata.
            chunk_content, knowledge_info_dict = self.parse_retriever_list(
                retriever_result_list,
                search_doc_id_to_knowledge_id_dict,
            )

            # 3. Build the prompt.
            prompt = self._build_prompt(chat_json, chunk_content)
            logger.info(f"请求的提示词:{prompt}")

            # 4. Call the LLM without streaming; the last result wins.
            model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
            chat_resp, chat_id = "", None
            async for chunk in self.vllm_client.chat(
                prompt,
                model,
                stream=False,
                history=[],
                **self._sampling_params(chat_json),
            ):
                chat_resp, chat_id = self._unpack_sync_chunk(chunk)
            if not chat_id:
                # The result carried no id; mint one so the retrieval info can still be stored and looked up.
                import uuid
                chat_id = uuid.uuid4().hex
            # Store the retrieval info so search_slice can return it for this chat id.
            self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))

            logger.info(f"LLM生成的回答:{chat_resp[:100]}...")

            # 5. Return the result.
            return {
                "code": 200,
                "data": {
                    "answer": chat_resp,
                    "chat_id": chat_id,
                },
            }
        except Exception as e:
            logger.error(f"非流式RAG聊天出错:{e}")
            return {
                "code": 500,
                "message": str(e),
            }

    async def generate_rag_response(self, chat_json, chunk_content):
        """Stream LLM events for a prompt filled with ``chunk_content``.

        Every ``"add"`` event is rewritten so that its ``"data"`` holds the full
        answer accumulated so far rather than just the latest delta. All other
        events are passed through unchanged.
        """
        prompt = self._build_prompt(chat_json, chunk_content)
        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
        chat_resp = ""
        async for chunk in self.vllm_client.chat(
            prompt,
            model,
            stream=True,
            history=[],
            **self._sampling_params(chat_json),
        ):
            if chunk.get("event") == "add":
                chat_resp += chunk.get("data") or ""
                chunk["data"] = chat_resp
            yield chunk

    def parse_retriever_list(self, retriever_result_list, search_doc_id_to_knowledge_id_dict):
        """Build prompt context and per-knowledge-base chunk metadata.

        Args:
            retriever_result_list: chunk dicts with ``content``, ``doc_id``,
                ``chunk_id``, ``rerank_score`` and ``metadata["source"]``.
            search_doc_id_to_knowledge_id_dict: ``doc_id`` -> knowledge id.

        Returns:
            ``(chunk_content, knowledge_info_dict)``. ``chunk_content`` joins the
            chunk texts with newlines. ``knowledge_info_dict`` maps each knowledge
            id to a list of ``{"doc_id", "doc_name", "chunk_nums",
            "chunk_info_list"}`` dicts, where each ``chunk_info_list`` entry is
            ``{"chunk_id", "rerank_score", "chunk_len"}``.
        """
        doc_info = {}
        chunk_texts = []
        # Group chunks by document, preserving retrieval order.
        for retriever in retriever_result_list:
            chunk_text = retriever["content"]
            chunk_texts.append(chunk_text)
            doc_name = retriever["metadata"]["source"]
            doc_entry = doc_info.setdefault(retriever["doc_id"], {"doc_name": doc_name, "chunk_list": []})
            doc_entry["chunk_list"].append({
                "chunk_id": retriever["chunk_id"],
                "rerank_score": retriever["rerank_score"],
                "chunk_len": len(chunk_text),
            })

        # Group documents by the knowledge base they were retrieved from.
        knowledge_info_dict = {}
        for doc_id, info in doc_info.items():
            knowledge_id = search_doc_id_to_knowledge_id_dict.get(doc_id)
            knowledge_info_dict.setdefault(knowledge_id, []).append({
                "doc_id": doc_id,
                "doc_name": info["doc_name"],
                "chunk_nums": len(info["chunk_list"]),
                "chunk_info_list": info["chunk_list"],
            })
        # Newline-separate chunks so neighbouring texts do not fuse in the prompt.
        return "\n".join(chunk_texts), knowledge_info_dict

    async def generate_event(self, chat_json, request):
        """Stream a RAG answer as events, storing retrieval info under the chat id.

        Stops with an ``"interrupted"`` event when the client disconnects. On any
        error it emits a ``"finish"`` event instead of raising.
        """
        chat_id = ""
        try:
            logger.info(f"rag聊天的请求参数:{chat_json}")
            retriever_result_list, search_doc_id_to_knowledge_id_dict = self.retriever_result(chat_json)
            logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
            chunk_content, knowledge_info_dict = self.parse_retriever_list(
                retriever_result_list, search_doc_id_to_knowledge_id_dict
            )
            stored = False
            async for event in self.generate_rag_response(chat_json, chunk_content):
                chat_id = event.get("id")
                # Store once, on the first event that actually carries an id.
                if not stored and chat_id:
                    self.redis_client.set(chat_id, json.dumps(knowledge_info_dict))
                    stored = True
                yield event
                if await request.is_disconnected():
                    logger.info(f"chat id:{chat_id}连接中断")
                    yield {"id": chat_id, "event": "interrupted", "data": ""}
                    return
        except Exception as e:
            logger.error(f"执行出错:{e}")
            yield {"id": chat_id, "event": "finish", "data": ""}

    async def query_redis_chunk_by_chat_id(self, chat_id):
        """Pick up to 3 random chunk ids from the retrieval info stored for ``chat_id``.

        Returns:
            ``(knowledge_id, chunk_id_list)``. ``knowledge_id`` is the first
            knowledge base recorded for the chat. The chunk ids are sampled from
            that knowledge base only, so all of them can be looked up in that one
            collection. Returns ``(None, [])`` when nothing is stored.
        """
        chunk_redis_str = self.redis_client.get(chat_id)
        if chunk_redis_str is None:
            logger.info(f"no chunk info stored in redis for chat id:{chat_id}")
            return None, []
        chunk_json = json.loads(chunk_redis_str)
        logger.info(f"redis中存储的chunk 信息:{chunk_json}")
        if not chunk_json:
            return None, []
        knowledge_id = next(iter(chunk_json))
        chunk_list = [
            chunk
            for doc in chunk_json[knowledge_id]
            for chunk in (doc.get("chunk_info_list") or [])
        ]
        random_chunk_list = random.sample(chunk_list, 3) if len(chunk_list) > 3 else chunk_list
        return knowledge_id, [chunk["chunk_id"] for chunk in random_chunk_list]

    async def generate_relevant_query(self, query_json):
        """Ask the LLM for follow-up questions, optionally grounded on a previous chat's chunks.

        When ``query_json["chat_id"]`` has stored chunks, the history sent to the
        LLM is those chunks (as an assistant message) plus the last message of
        ``query_json["messages"]``. The caller's list is not modified. Otherwise
        the messages are sent unchanged.

        Returns:
            ``{"code": 200, "data": <value of the "问题" key of the model's JSON>}``.
        """
        chat_id = query_json.get("chat_id")
        messages = query_json.get("messages")
        knowledge_id, chunk_id_list = None, []
        if chat_id:
            knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
        if knowledge_id and chunk_id_list:
            chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
            chunk_content = "\n".join(chunk_content_list)
            messages = [{"role": "assistant", "content": chunk_content}, messages[-1]]
        model = query_json.get("model")
        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)

        async for result in query_result:
            result = result.strip()
            logger.info(f"模型生成的问题:{result}")
            parsed = self._parse_llm_json(result)
            query_list = parsed.get("问题")
            return {"code": 200, "data": query_list}

    async def search_slice(self):
        """Return the retrieval info stored for ``self.chat_id`` with ``"code": 200``.

        Returns ``{"code": 500, "message": ...}`` if the entry is missing or unreadable.
        """
        try:
            chunk_redis_str = self.redis_client.get(self.chat_id)
            if chunk_redis_str is None:
                raise LookupError(f"no slice info stored for chat id {self.chat_id}")
            chunk_json = json.loads(chunk_redis_str)
            chunk_json["code"] = 200
        except Exception as e:
            logger.error(f"查询redis报错:{e}")
            chunk_json = {
                "code": 500,
                "message": str(e),
            }
        return chunk_json
|