|
|
@@ -7,6 +7,7 @@ import re
|
|
|
import gc
|
|
|
import redis
|
|
|
from utils.get_logger import setup_logger
|
|
|
+import random
|
|
|
logger = setup_logger(__name__)
|
|
|
|
|
|
rerank_model_mapping = {
|
|
|
@@ -184,8 +185,35 @@ class ChatRetrieverRag:
|
|
|
return
|
|
|
# yield json.dumps({"id": chat_id, "event": "finish", "data": ""}, ensure_ascii=False)
|
|
|
|
|
|
+ async def query_redis_chunk_by_chat_id(self, chat_id):
|
|
|
+ chunk_redis_str = self.redis_client.get(chat_id)
|
|
|
+ chunk_json = json.loads(chunk_redis_str)
|
|
|
+ logger.info(f"redis中存储的chunk 信息:{chunk_json}")
|
|
|
+
|
|
|
+ knowledge_id = chunk_json.get("knowledge_id")
|
|
|
+ chunk_list = []
|
|
|
+ for chunk in chunk_json["doc"]:
|
|
|
+ doc_chunk_list = chunk.get("chunk_info_list")
|
|
|
+ chunk_list.extend(doc_chunk_list)
|
|
|
+ if len(chunk_list) > 3:
|
|
|
+ random_chunk_list = random.sample(chunk_list, 3)
|
|
|
+ else:
|
|
|
+ random_chunk_list = chunk_list
|
|
|
+ chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
|
|
|
+ return knowledge_id, chunk_id_list
|
|
|
+
|
|
|
async def generate_relevant_query(self, query_json):
|
|
|
- messages = query_json.get("messages")
|
|
|
+ chat_id = query_json.get("chat_id")
|
|
|
+ if chat_id:
|
|
|
+ knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
|
|
|
+ chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
|
|
|
+ chunk_content = "\n".join(chunk_content_list)
|
|
|
+ # generate_query_prompt = "请根据下面提供的信息,帮我生成"
|
|
|
+ query_messages = query_json.get("messages")
|
|
|
+ user_dict = query_messages.pop()
|
|
|
+ messages = [{"role": "assistant", "content": chunk_content}, user_dict]
|
|
|
+ else:
|
|
|
+ messages = query_json.get("messages")
|
|
|
model = query_json.get("model")
|
|
|
|
|
|
query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
|