Bladeren bron

问题生成逻辑修改

weiyu 6 maanden geleden
bovenliggende
commit
fab5cc25f0
3 gewijzigde bestanden met toevoegingen van 59 en 1 verwijderingen
  1. 29 1
      rag/chat_message.py
  2. 14 0
      rag/db.py
  3. 16 0
      rag/vector_db/milvus_vector.py

+ 29 - 1
rag/chat_message.py

@@ -7,6 +7,7 @@ import re
 import gc
 import redis
 from utils.get_logger import setup_logger
+import random
 logger = setup_logger(__name__)
 
 rerank_model_mapping = {
@@ -184,8 +185,35 @@ class ChatRetrieverRag:
             return
             # yield json.dumps({"id": chat_id, "event": "finish", "data": ""}, ensure_ascii=False)
 
+    async def query_redis_chunk_by_chat_id(self, chat_id):
+        chunk_redis_str = self.redis_client.get(chat_id)
+        chunk_json = json.loads(chunk_redis_str)
+        logger.info(f"redis中存储的chunk 信息:{chunk_json}")
+
+        knowledge_id = chunk_json.get("knowledge_id")
+        chunk_list = []
+        for chunk in chunk_json["doc"]:
+            doc_chunk_list = chunk.get("chunk_info_list")
+            chunk_list.extend(doc_chunk_list)
+        if len(chunk_list) > 3:
+            random_chunk_list = random.sample(chunk_list, 3)
+        else:
+            random_chunk_list = chunk_list
+        chunk_id_list = [chunk["chunk_id"] for chunk in random_chunk_list]
+        return knowledge_id, chunk_id_list
+    
     async def generate_relevant_query(self, query_json):
-        messages = query_json.get("messages")
+        chat_id = query_json.get("chat_id")
+        if chat_id:
+            knowledge_id, chunk_id_list = await self.query_redis_chunk_by_chat_id(chat_id)
+            chunk_content_list = MilvusOperate(collection_name=knowledge_id)._search_by_chunk_id_list(chunk_id_list)
+            chunk_content = "\n".join(chunk_content_list)
+            # generate_query_prompt = "请根据下面提供的信息,帮我生成"
+            query_messages = query_json.get("messages")
+            user_dict = query_messages.pop()
+            messages = [{"role": "assistant", "content": chunk_content}, user_dict]
+        else:
+            messages = query_json.get("messages")
         model = query_json.get("model")
 
         query_result = self.vllm_client.chat(model=model, stream=False, history=messages)

+ 14 - 0
rag/db.py

@@ -156,6 +156,20 @@ class MilvusOperate:
 
         return resp
     
+    def _search_by_chunk_id_list(self, chunk_id_list):
+        if self._has_collection():
+            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
+        else:
+            query_result = []
+        logger.info(f"召回的切片列表查询切片信息:{query_result}")
+        
+        chunk_content_list = []
+        for chunk_dict in query_result:
+            chunk_content = chunk_dict.get("content")
+            chunk_content_list.append(chunk_content)
+
+        return chunk_content_list
+    
     
     def _search_by_key_word(self, search_json):
         if self._has_collection():

+ 16 - 0
rag/vector_db/milvus_vector.py

@@ -276,6 +276,22 @@ class HybridRetriever:
             logger.info(f"根据chunk id 查询出错:{e}")
             query_filter_results = [{"code": 500}]
         return query_filter_results
+    
+    def query_chunk_id_list(self, chunk_id_list):
+        # chunk id,查询切片
+        query_output_field = [
+            "content",
+            "doc_id",
+            "chunk_id",
+            # "metadata"
+        ]
+        query_expr = f"chunk_id in {chunk_id_list}"
+        try:
+            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
+        except Exception as e:
+            logger.info(f"根据chunk id 查询出错:{e}")
+            query_filter_results = [{"code": 500}]
+        return query_filter_results
 
     def update_data(self, chunk_id, chunk):
         # 根据chunk id查询对应的信息,