weiyu committed 6 months ago
Commit e0706614c0
52 files changed, 2821 additions and 0 deletions
  1. config.py (+77, -0)
  2. network_search/search.py (+0, -0)
  3. rag/__init__.py (+0, -0)
  4. rag/__pycache__/__init__.cpython-310.pyc (BIN)
  5. rag/__pycache__/__init__.cpython-311.pyc (BIN)
  6. rag/__pycache__/chat_message.cpython-310.pyc (BIN)
  7. rag/__pycache__/chat_message.cpython-311.pyc (BIN)
  8. rag/__pycache__/db.cpython-310.pyc (BIN)
  9. rag/__pycache__/db.cpython-311.pyc (BIN)
  10. rag/__pycache__/documents_process.cpython-310.pyc (BIN)
  11. rag/__pycache__/file_process.cpython-310.pyc (BIN)
  12. rag/__pycache__/file_process.cpython-311.pyc (BIN)
  13. rag/__pycache__/llm.cpython-310.pyc (BIN)
  14. rag/__pycache__/llm.cpython-311.pyc (BIN)
  15. rag/__pycache__/load_model.cpython-310.pyc (BIN)
  16. rag/__pycache__/load_model.cpython-311.pyc (BIN)
  17. rag/chat_message.py (+223, -0)
  18. rag/db.py (+440, -0)
  19. rag/document_load/__pycache__/image_load.cpython-310.pyc (BIN)
  20. rag/document_load/__pycache__/office_load.cpython-310.pyc (BIN)
  21. rag/document_load/__pycache__/pdf_load.cpython-310.pyc (BIN)
  22. rag/document_load/__pycache__/pdf_load.cpython-311.pyc (BIN)
  23. rag/document_load/__pycache__/txt_load.cpython-310.pyc (BIN)
  24. rag/document_load/__pycache__/txt_load.cpython-311.pyc (BIN)
  25. rag/document_load/__pycache__/word_load.cpython-310.pyc (BIN)
  26. rag/document_load/image_load.py (+26, -0)
  27. rag/document_load/office_load.py (+26, -0)
  28. rag/document_load/pdf_load.py (+273, -0)
  29. rag/document_load/txt_load.py (+9, -0)
  30. rag/documents_process.py (+403, -0)
  31. rag/file_process.py (+184, -0)
  32. rag/llm.py (+187, -0)
  33. rag/load_model.py (+16, -0)
  34. rag/vector_db/__init__.py (+0, -0)
  35. rag/vector_db/__pycache__/__init__.cpython-310.pyc (BIN)
  36. rag/vector_db/__pycache__/__init__.cpython-311.pyc (BIN)
  37. rag/vector_db/__pycache__/milvus_vector.cpython-310.pyc (BIN)
  38. rag/vector_db/__pycache__/milvus_vector.cpython-311.pyc (BIN)
  39. rag/vector_db/milvus_vector.py (+507, -0)
  40. rag_server.py (+146, -0)
  41. requirements.txt (+104, -0)
  42. response_info.py (+85, -0)
  43. utils/__init__.py (+0, -0)
  44. utils/__pycache__/__init__.cpython-310.pyc (BIN)
  45. utils/__pycache__/__init__.cpython-311.pyc (BIN)
  46. utils/__pycache__/get_logger.cpython-310.pyc (BIN)
  47. utils/__pycache__/get_logger.cpython-311.pyc (BIN)
  48. utils/__pycache__/upload_file_to_oss.cpython-310.pyc (BIN)
  49. utils/__pycache__/upload_file_to_oss.cpython-311.pyc (BIN)
  50. utils/get_logger.py (+40, -0)
  51. utils/simsun.ttc (BIN)
  52. utils/upload_file_to_oss.py (+75, -0)

+ 77 - 0
config.py

@@ -0,0 +1,77 @@
+# Test environment
+
+# Local Milvus
+milvus_uri = "http://127.0.0.1:19530"
+
+# Test environment MySQL database config
+mysql_config = {
+        "host": "xia0miduo.gicp.net",
+        "port": 3336,
+        "user": "root",
+        "password": "T@kai2025",
+        "database": "chat_deepseek",
+        # "database": "rag_master",
+}
+
+# Test MinIO config
+minio_config = {
+    "minio_endpoint" : 'xia0miduo.gicp.net:9000',
+    "minio_access_key" : 'fileadmin',
+    "minio_secret_key" : 'fileadmin',
+    "minio_bucket" : 'papbtest',
+    "minio_url": "http://xia0miduo.gicp.net:9000",
+     "flag": False
+}
+
+# Test environment vLLM endpoint
+# vllm_url = "http://xia0miduo.gicp.net:8102/v1"
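+# Maps model name to its vLLM endpoint (the /v1 suffix suggests an OpenAI-compatible server);
+# presumably looked up by model name in rag/llm.py's VllmApi, which is not part of this hunk.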
+model_name_vllm_url_dict = {
+    "DeepSeek-R1-Distill-Qwen-14B": "http://xia0miduo.gicp.net:8102/v1"
+}
+
+# Test environment Redis config
+redis_config_dict = {
+    "host": "localhost",
+    "port": 6379,
+    "db": 1
+} 
+
+
+
+# Production environment
+
+# # Production MySQL database config
+# mysql_config = {
+#         "host": "127.0.0.1",
+#         "port": 3306,
+#         "user": "root",
+#         "password": "Lx304307910",
+#         "database": "chat_deepseek",
+# }
+
+# # Production Milvus config
+# milvus_uri = "http://127.0.0.1:19530"
+
+# # Production MinIO config
+# minio_config = {
+#         "minio_endpoint" : 'minio.ryuiso.com:59000',
+#         "minio_access_key" : 'oss_library',
+#         "minio_secret_key" : 'yDkG9YJiC92G3vk52goST',
+#         "minio_bucket" : 'deepseek-doc',
+#         "minio_url": "https://minio.ryuiso.com:59000",
+#         "flag": True
+# }
+
+# # Production vLLM endpoint
+# # vllm_url = "http://10.1.27.6:11817/v1"
+# model_name_vllm_url_dict = {
+#     "DeepSeek-R1-Distill-Llama-70B": "http://10.1.27.6:11817/v1",
+#     "Qwen2-72B": "http://10.1.27.6:11818/v1",
+# }
+
+# # Production Redis config
+# redis_config_dict = {
+#     "host": "localhost",
+#     "port": 6379,
+#     "db": 1
+# } 

+ 0 - 0
network_search/search.py


+ 0 - 0
rag/__init__.py


BIN
rag/__pycache__/__init__.cpython-310.pyc


BIN
rag/__pycache__/__init__.cpython-311.pyc


BIN
rag/__pycache__/chat_message.cpython-310.pyc


BIN
rag/__pycache__/chat_message.cpython-311.pyc


BIN
rag/__pycache__/db.cpython-310.pyc


BIN
rag/__pycache__/db.cpython-311.pyc


BIN
rag/__pycache__/documents_process.cpython-310.pyc


BIN
rag/__pycache__/file_process.cpython-310.pyc


BIN
rag/__pycache__/file_process.cpython-311.pyc


BIN
rag/__pycache__/llm.cpython-310.pyc


BIN
rag/__pycache__/llm.cpython-311.pyc


BIN
rag/__pycache__/load_model.cpython-310.pyc


BIN
rag/__pycache__/load_model.cpython-311.pyc


+ 223 - 0
rag/chat_message.py

@@ -0,0 +1,223 @@
+from rag.db import MilvusOperate
+from rag.load_model import *
+from rag.llm import VllmApi
+from config import redis_config_dict
+import json
+import re
+import gc
+import redis
+from utils.get_logger import setup_logger
+logger = setup_logger(__name__)
+
+rerank_model_mapping = {
+    "bce_rerank_model": (bce_rerank_tokenizer, bce_rerank_base_model),
+    "rerank": (bce_rerank_tokenizer, bce_rerank_base_model)
+}
+redis_host = redis_config_dict.get("host")
+redis_port = redis_config_dict.get("port")
+redis_db = redis_config_dict.get("db")
+
+class ChatRetrieverRag:
+    def __init__(self, chat_json: dict = None, chat_id: str=None):
+        self.chat_id = chat_id
+        self.redis_client = redis.StrictRedis(host=redis_host, port=redis_port, db=redis_db)
+        if not chat_id:
+            self.vllm_client = VllmApi(chat_json)
+
+    def rerank_result(self,query, top_k, hybrid_search_result, rerank_embedding_name):
+        rerank_list = []
+        for result in hybrid_search_result:
+            rerank_list.append([query, result["content"]])
+
+        tokenizer, rerank_model = rerank_model_mapping.get(rerank_embedding_name)
+        # rerank: score each (query, passage) pair with the cross-encoder
+        rerank_model.eval()
+        inputs = tokenizer(rerank_list, padding=True, truncation=True, max_length=512, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        with torch.no_grad():
+            logits = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
+        logits = logits.detach().cpu()
+        scores = torch.sigmoid(logits).tolist()
+        # scores_list = scores.tolist()
+        logger.info(f"重排序的得分:{scores}")
+
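+        # pair each rerank score with its search hit, sort descending by score, and attach the score to each hit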
+        sorted_pairs = sorted(zip(scores, hybrid_search_result), key=lambda x: x[0], reverse=True)
+        sorted_scores, sorted_search = zip(*sorted_pairs)
+        sorted_scores_list = list(sorted_scores)
+        sorted_search_list = list(sorted_search)
+        for score, search in zip(sorted_scores_list, sorted_search_list):
+            search["rerank_score"] = score
+        del inputs, logits
+        gc.collect()
+        torch.cuda.empty_cache()
+        return sorted_search_list
+    
+    def retriever_result(self, chat_json):
+        mode_search_mode = {
+            "embedding": "dense",
+            "keyword": "sparse",
+            "mixed": "hybrid",
+        }
+        collection_name = chat_json.get("knowledgeIds")  # 
+        top = chat_json.get("sliceCount", 5)
+        retriever_info = json.loads(chat_json.get("knowledgeInfo", "{}"))
+        rag_embedding_name = chat_json.get("embeddingId")
+        mode_rag = retriever_info.get("recall_method")
+        mode = mode_search_mode.get(mode_rag, "hybrid")
+        rerank_embedding_name = retriever_info.get("rerank_model_name")
+        rerank_status = retriever_info.get("rerank_status")
+        query = chat_json.get("query")
+
+        hybrid_search_result = MilvusOperate(
+            collection_name=collection_name,embedding_name=rag_embedding_name)._search(query, k=top, mode=mode)
+        logger.info(f"根据{collection_name}检索到结果:{hybrid_search_result}")
+        if len(hybrid_search_result) <= 0:
+            rerank_result_list = []
+        elif rerank_status:
+            rerank_result_list = self.rerank_result(query, top, hybrid_search_result, rerank_embedding_name)
+        else:
+            for result in hybrid_search_result:
+                result["rerank_score"] = 1
+            rerank_result_list = hybrid_search_result
+        return rerank_result_list
+
+
+    async def generate_rag_response(self, chat_json, chunk_content):
+        # logger.info(f"rag聊天的请求参数:{chat_json}")
+        # retriever_result_list = self.retriever_result(chat_json)
+        # logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
+        prompt = chat_json.get("prompt")
+        query = chat_json.get("query")
+
+        # chunk_content = ""
+        # for retriever in retriever_result_list:
+        #     chunk_content += retriever["content"]
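+        # fill the prompt template: {知识} is replaced with the retrieved chunk text, {用户} with the user query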
+        prompt = prompt.replace("{知识}", chunk_content).replace("{用户}", query)
+        logger.info(f"请求的提示词:{prompt}")
+        temperature = float(chat_json.get("temperature", 0.6))
+        top_p = float(chat_json.get("topP", 0.7))
+        max_token = chat_json.get("maxToken", 4096)
+
+        # call the model and stream back the result
+        model = chat_json.get("model", "DeepSeek-R1-Distill-Qwen-14B")
+        # model = "DeepSeek-R1-Distill-Qwen-14B"
+        chat_resp = ""
+        for chunk in self.vllm_client.chat(prompt, model, temperature=temperature, top_p=top_p, max_tokens=max_token, stream=True, history=[]):
+            if chunk.get("event") == "add":
+                chat_resp += chunk.get("data")
+                chunk["data"] = chat_resp
+                yield chunk
+            else:
+                yield chunk
+            
+    def parse_retriever_list(self, retriever_result_list):
+        doc_info = {}
+        chunk_content = ""
+
+        # group the hits into one JSON entry per document
+        for retriever in retriever_result_list:
+            chunk_text = retriever["content"]
+            chunk_content += chunk_text
+            doc_id = retriever["doc_id"]
+            chunk_id = retriever["chunk_id"]
+            rerank_score = retriever["rerank_score"]
+            doc_name = retriever["metadata"]["source"]
+            if doc_id in doc_info:
+                c_d = {
+                    "chunk_id": chunk_id,
+                    "rerank_score": rerank_score,
+                    "chunk_len": len(chunk_text)
+                }
+                doc_info[doc_id]["chunk_list"].append(c_d)
+                # doc_info[doc_id]["rerank_score"].append(rerank_score)
+            else:
+                c_d = {
+                    "chunk_id": chunk_id,
+                    "rerank_score": rerank_score,
+                    "chunk_len": len(chunk_text)
+                }
+                doc_info[doc_id] = {
+                    "doc_name": doc_name,
+                    "chunk_list": [c_d],
+                }
+
+        doc_list = []
+        for k, v in doc_info.items():
+            d = {}
+            d["doc_id"] = k
+            d["doc_name"] = v.get("doc_name")
+            d["chunk_nums"] = len(v.get("chunk_list"))
+            d["chunk_info_list"] = v.get("chunk_list")
+            # d["chunk_len"] = len(v.get("chunk_id_list"))
+            doc_list.append(d)
+
+        return chunk_content, doc_list
+
+    async def generate_event(self, chat_json, request):
+        chat_id = ""
+        try:
+            logger.info(f"rag聊天的请求参数:{chat_json}")
+            knowledge_id = chat_json.get("knowledgeIds")
+            retriever_result_list = self.retriever_result(chat_json)
+            logger.info(f"向量库中获取的最终结果:{retriever_result_list}")
+            chunk_content, doc_list = self.parse_retriever_list(retriever_result_list)
+
+            first = True
+            async for event in self.generate_rag_response(chat_json, chunk_content):
+                chat_id = event.get("id")
+                if first:
+                    json_dict = {"knowledge_id": knowledge_id, "doc": doc_list}
+                    self.redis_client.set(chat_id, json.dumps(json_dict))
+                    # logger.info(f"返回检索出的切片信息:{json_dict}")
+                    # yield {"id": chat_id, "event": "json", "data": json_dict}
+                    first = False
+                yield event
+                # yield json.dumps(event, ensure_ascii=False)
+                if await request.is_disconnected():
+                    logger.info(f"chat id:{chat_id}连接中断")
+                    yield {"id": chat_id, "event": "interrupted", "data": ""}
+                    return
+                    # yield json.dumps({"id": chat_id, "event": "interrupted", "data": ""}, ensure_ascii=False)
+        except Exception as e:
+            logger.info(f"执行出错:{e}")
+            yield {"id": chat_id, "event": "finish", "data": ""}
+            return
+            # yield json.dumps({"id": chat_id, "event": "finish", "data": ""}, ensure_ascii=False)
+
+    async def generate_relevant_query(self, query_json):
+        messages = query_json.get("messages")
+        model = query_json.get("model")
+
+        query_result = self.vllm_client.chat(model=model, stream=False, history=messages)
+        
+        for result in query_result:
+            # result_json = json.loads(result)
+            # logger.info(f"生成的问题:{result_json}")
+            # result_str = result_json.get("choices", [{}])[0].get("message", {}).get("content", "").strip()
+            result = result.strip()
+            logger.info(f"模型生成的问题:{result}")
+            try:
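+                # strip a ```json fence if the model wrapped its answer in one, then parse the JSON body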
+                if "```json" in result:
+                    json_pattern = r'```json\s(.*?)```'
+                    matches = re.findall(json_pattern, result, re.DOTALL)
+                    result = matches[0]
+                query_json = json.loads(result)
+            except Exception as e:
+                query_json = eval(result)
+            query_list = query_json.get("问题")
+
+        return {"code": 200, "data": query_list}
+
+    async def search_slice(self):
+        try:
+            chunk_redis_str = self.redis_client.get(self.chat_id)
+            chunk_json = json.loads(chunk_redis_str)
+            chunk_json["code"] = 200
+        except Exception as e:
+            logger.error(f"查询redis报错:{e}")
+            chunk_json = {
+                "code": 500,
+                "message": str(e)
+            }
+        return chunk_json
+

+ 440 - 0
rag/db.py

@@ -0,0 +1,440 @@
+from rag.vector_db.milvus_vector import HybridRetriever
+from response_info import generate_message, generate_response
+from utils.get_logger import setup_logger
+from datetime import datetime
+from uuid import uuid1
+import mysql.connector
+from mysql.connector import pooling, Error
+from concurrent.futures import ThreadPoolExecutor, TimeoutError
+from config import milvus_uri, mysql_config
+
+logger = setup_logger(__name__)
+# uri = "http://localhost:19530"
+
+try:
+    POOL = pooling.MySQLConnectionPool(
+        pool_name="mysql_pool",
+        pool_size=10,
+        **mysql_config
+    )
+    logger.info("MySQL 连接池初始化成功")
+except Error as e:
+    logger.info(f"初始化 MySQL 连接池失败: {e}")
+    POOL = None
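+# NOTE: if pool initialization fails, POOL stays None and MysqlOperate.get_connection() below will raise when it is used.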
+
+
+class MilvusOperate:
+
+    def __init__(self, collection_name: str = "default", embedding_name:str = "e5"):
+        self.collection = collection_name
+        self.hybrid_retriever = HybridRetriever(uri=milvus_uri, embedding_name=embedding_name, collection_name=collection_name)
+        self.mysql_client = MysqlOperate()
+
+    def _has_collection(self):
+        is_collection = self.hybrid_retriever.has_collection()
+        return is_collection
+    
+    def _create_collection(self):
+        if self._has_collection():
+            resp = {"code": 400, "message": "数据库已存在"}
+        else:
+            create_result = self.hybrid_retriever.build_collection()
+            resp = generate_message(create_result)
+        return resp
+
+    
+    def _delete_collection(self):
+        delete_result = self.hybrid_retriever.delete_collection(self.collection)
+        resp = generate_message(delete_result)
+        return resp
+
+    
+    def _put_by_id(self, slice_json):
+        slice_id = slice_json.get("slice_id", None)
+        slice_text = slice_json.get("slice_text", None)
+        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
+        if update_result.endswith("success"):
+            # on success, update the knowledge-base total word count and the document word count in MySQL
+            update_json = {}
+            update_json["knowledge_id"] = slice_json.get("knowledge_id")
+            update_json["doc_id"] = slice_json.get("document_id")
+            update_json["chunk_len"] = chunk_len
+            update_json["operate"] = "update"
+            update_json["chunk_id"] = slice_id
+            update_json["chunk_text"] = slice_text
+            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
+        else:
+            update_flag = False
+            
+        if not update_flag:
+            update_result = "update_error"    
+        
+        resp = generate_message(update_result)
+        return resp
+    
+    def _insert_slice(self, slice_json):
+        slice_id = str(uuid1())
+        knowledge_id = slice_json.get("knowledge_id")
+        doc_id = slice_json.get("document_id")
+        slice_text = slice_json.get("slice_text", None)
+        doc_name = slice_json.get("doc_name")
+        chunk_len = len(slice_text)
+        metadata = {
+            "content": slice_text,
+            "doc_id": doc_id,
+            "chunk_id": slice_id,
+            "metadata": {"source": doc_name, "chunk_len": chunk_len}
+        }
+        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
+        if insert_flag:
+            # on success, update the knowledge-base total word count and the document word count in MySQL
+            update_json = {}
+            update_json["knowledge_id"] = slice_json.get("knowledge_id")
+            update_json["doc_id"] = slice_json.get("document_id")
+            update_json["chunk_len"] = chunk_len
+            update_json["operate"] = "insert"
+            update_json["chunk_id"] = slice_id
+            update_json["chunk_text"] = slice_text
+            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
+        else:
+            logger.error(f"插入向量库出错:{insert_str}")
+            update_flag = False
+            update_str = "向量库写入出错"
+            # pass
+            
+        if not update_flag:
+            logger.error(f"新增切片中mysql数据库出错:{update_str}")
+            insert_result = "insert_error"
+        else:
+            insert_result = "insert_success"
+        
+        resp = generate_message(insert_result)
+        return resp
+    
+
+    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
+        logger.info(f"删除的切片id:{chunk_id}")
+        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
+        if delete_result.endswith("success"):
+            chunk_len = delete_chunk_len[0]
+            update_json = {
+                "knowledge_id": knowledge_id,
+                "doc_id": document_id,
+                "chunk_len": -chunk_len,
+                "operate": "delete",
+                "chunk_id": chunk_id
+            }
+            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
+        else:
+            logger.error("根据chunk id删除向量库失败")
+            update_flag = False
+            update_str = "根据chunk id删除失败"
+        
+        if not update_flag:
+            logger.error(update_str)
+            delete_result = "delete_error"
+
+        
+        resp = generate_message(delete_result)
+        return resp
+    
+    def _delete_by_doc_id(self, doc_id: str = None):
+        logger.info(f"删除数据的id:{doc_id}")
+        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
+        resp = generate_message(delete_result)
+        return resp
+    
+
+    def _search_by_chunk_id(self, chunk_id):
+        if self._has_collection():
+            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
+        else:
+            query_result = []
+        logger.info(f"根据切片查询到的信息:{query_result}")
+        resp = generate_response(query_result)
+
+        return resp
+    
+    
+    def _search_by_key_word(self, search_json):
+        if self._has_collection():
+            doc_id = search_json.get("document_id", None)
+            text = search_json.get("text", None)
+            page_num = search_json.get("pageNum", 1)
+            page_size = search_json.get("pageSize", 10)
+            page_num = search_json.get("pageNum")  # 根据传过来的id处理对应知识库
+            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
+        else:
+            query_result = []
+        resp = generate_response(query_result,page_num,page_size)
+
+        return resp
+    
+    def _insert_data(self, docs):
+        for doc in docs:
+            chunk = doc.get("content")
+            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
+            if not insert_flag:
+                break
+        resp = insert_flag
+        return resp, insert_info
+    
+    def _batch_insert_data(self, docs, text_lists):
+        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
+
+        resp = insert_flag
+        return resp, insert_info
+
+    def _search(self, query, k, mode):
+        search_result = self.hybrid_retriever.search(query, k, mode)
+        return search_result
+
+
+class MysqlOperate:
+
+    def get_connection(self):
+        """
+        从连接池中获取一个连接
+        :return: 数据库连接对象
+        """
+        try:
+            with ThreadPoolExecutor() as executor:
+                future = executor.submit(POOL.get_connection)
+                connection = future.result(timeout=5.0)  # set a 5-second timeout
+
+                logger.info("成功从连接池获取连接")
+                return connection, "success"
+        except TimeoutError:
+            logger.error("获取mysql数据库连接池超时")
+            return None, "mysql获取连接池超时"
+        except Error as e:
+            logger.error(f"无法从连接池获取连接: {e}")
+            return None, str(e)
+
+    def insert_to_slice(self, docs, knowledge_id, doc_id):
+        """
+        插入数据到切片信息表中 slice_info
+        """
+        connection = None
+        cursor = None
+        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        values = []
+        connection, connection_info = self.get_connection()
+        if not connection:
+            return False, connection_info
+        
+        for chunk in docs:
+            slice_id = chunk.get("chunk_id")
+            slice_text = chunk.get("content")
+            chunk_index = chunk.get("metadata").get("chunk_index")
+            values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index))
+        try:
+            cursor = connection.cursor()
+            insert_sql = """
+                INSERT INTO slice_info (
+                    slice_id,
+                    knowledge_id,
+                    document_id,
+                    slice_text,
+                    create_time,
+                    slice_index
+                ) VALUES (%s, %s, %s, %s, %s,%s)
+                """
+            
+            cursor.executemany(insert_sql, values)
+            connection.commit()
+            logger.info(f"批量插入切片数据成功。")
+            return True, "success"
+
+        except Error as e:
+            logger.error(f"数据库操作出错:{e}")
+            connection.rollback()
+            return False, str(e)
+        finally:
+            # if cursor:
+            cursor.close()
+            # if connection and connection.is_connected():
+            connection.close()
+
+    def delete_to_slice(self, doc_id):
+        """
+        删除 slice_info库中切片信息
+        """
+        connection = None
+        cursor = None
+        connection, connection_info = self.get_connection()
+        if not connection:
+            return False, connection_info
+        try:
+            cursor = connection.cursor()
+            delete_sql = f"DELETE FROM slice_info WHERE document_id = %s"
+            cursor.execute(delete_sql, (doc_id,))
+            connection.commit()
+            logger.info(f"删除数据成功")
+            return True, "success"
+
+        except Error as e:
+            logger.error(f"根据{doc_id}删除数据失败:{e}")
+            connection.rollback()
+            return False, str(e)
+        finally:
+            # if cursor:
+            cursor.close()
+            # if connection and connection.is_connected():
+            connection.close()
+
+    def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
+        """
+        批量插入数据到指定表
+        """
+        connection = None
+        cursor = None
+        connection, connection_info = self.get_connection()
+        if not connection:
+            return False, connection_info
+        
+        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+        values = []
+        for img_key, img_value in image_dict.items():
+            origin_text = img_key
+            media_url = img_value
+            values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now))
+        try:
+            cursor = connection.cursor()
+            insert_sql = """
+                INSERT INTO bm_media_replacement (
+                    knowledge_id,
+                    document_id,
+                    origin_text,
+                    media_type,
+                    media_url,
+                    create_time
+                ) VALUES (%s, %s, %s, %s, %s, %s)
+                """
+            cursor.executemany(insert_sql, values)
+            connection.commit()
+            logger.info(f"插入到bm_media_replacement表成功")
+            return True, "success"
+        except Error as e:
+            logger.error(f"数据库操作出错:{e}")
+            connection.rollback()
+            return False, str(e)
+        finally:
+            # if cursor:
+            cursor.close()
+            # if connection and connection.is_connected():
+            connection.close()
+
+    def delete_image_url(self, doc_id):
+        """
+        根据doc id删除bm_media_replacement中的数据
+        """
+        connection = None
+        cursor = None
+        connection, connection_info = self.get_connection()
+        if not connection:
+            return False, connection_info
+        
+        try:
+            cursor = connection.cursor()
+            delete_sql = f"DELETE FROM bm_media_replacement WHERE document_id = %s"
+            cursor.execute(delete_sql, (doc_id,))
+            connection.commit()
+            logger.info(f"根据{doc_id} 删除bm_media_replacement表中数据成功")
+            return True, "success"
+        except Error as e:
+            logger.error(f"根据{doc_id}删除 bm_media_replacement 数据库操作出错:{e}")
+            connection.rollback()
+            return False, str(e)
+        finally:
+            # if cursor:
+            cursor.close()
+            # if connection and connection.is_connected():
+            connection.close()
+
+    def update_total_doc_len(self, update_json):
+        """
+        更新长度表和文档长度表,删除slice info表, 插入slice info 切片信息
+        """
+        knowledge_id = update_json.get("knowledge_id")
+        doc_id = update_json.get("doc_id")
+        chunk_len = update_json.get("chunk_len")
+        operate = update_json.get("operate")
+        chunk_id = update_json.get("chunk_id")
+        chunk_text = update_json.get("chunk_text")
+        connection = None
+        cursor = None
+        connection, connection_info = self.get_connection()
+        if not connection:
+            return False, connection_info
+        try:
+            cursor = connection.cursor()
+            query_doc_word_num_sql = f"select word_num,slice_total from bm_document where document_id = %s"
+            query_knowledge_word_num_sql = f"select word_num from bm_knowledge where knowledge_id = %s"
+            cursor.execute(query_doc_word_num_sql, (doc_id,))
+            doc_result = cursor.fetchone()
+            logger.info(f"查询到的文档长度信息:{doc_result}")
+            cursor.execute(query_knowledge_word_num_sql, (knowledge_id, ))
+            knowledge_result = cursor.fetchone()
+            logger.info(f"查询到的知识库总长度信息:{knowledge_result}")
+            if not doc_result:
+                new_word_num = 0
+                slice_total = 0
+            else:
+                old_word_num = doc_result[0]
+                slice_total = doc_result[1]
+                new_word_num = old_word_num + chunk_len
+                slice_total -= 1 if slice_total else 0
+
+            if not knowledge_result:
+                new_knowledge_word_num = 0
+            else:
+                old_knowledge_word_num = knowledge_result[0]
+                new_knowledge_word_num = old_knowledge_word_num + chunk_len
+
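+            # "update" rewrites an existing slice, "insert" appends a new slice at the next index, anything else is treated as a delete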
+            if operate == "update":
+                update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s"
+                cursor.execute(update_sql, (new_word_num, doc_id))
+
+                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+                update_slice_sql = f"UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s"
+                cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id))
+            elif operate == "insert":
+                query_slice_info_index_sql = f"select MAX(slice_index) from slice_info where document_id = %s"
+                cursor.execute(query_slice_info_index_sql, (doc_id,))
+                chunk_index_result = cursor.fetchone()[0]
+                # logger.info(chunk_index_result)
+                if chunk_index_result:
+                    chunk_max_index = int(chunk_index_result)
+                else:
+                    chunk_max_index = 0
+
+                update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s"
+                cursor.execute(update_sql, (new_word_num, doc_id))
+
+                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+                insert_slice_sql = "INSERT INTO slice_info (slice_id,knowledge_id,document_id,slice_text,create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)"
+                cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index+1))
+            else:
+                update_sql = f"UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s"
+                cursor.execute(update_sql, (new_word_num, slice_total, doc_id))
+
+                # delete the slice matching this slice id
+                delete_slice_sql = f"DELETE FROM slice_info where slice_id = %s"
+                cursor.execute(delete_slice_sql, (chunk_id, ))
+
+            update_knowledge_sql = f"UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s"
+            cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id))
+
+            connection.commit()
+            logger.info("bm_document和bm_knowledge数据更新成功")
+            return True, "success"
+        except Error as e:
+            logger.error(f"数据库操作出错:{e}")
+            connection.rollback()
+            return False, str(e)
+        finally:
+            # if cursor:
+            cursor.close()
+            # if connection and connection.is_connected():
+            connection.close()

BIN
rag/document_load/__pycache__/image_load.cpython-310.pyc


BIN
rag/document_load/__pycache__/office_load.cpython-310.pyc


BIN
rag/document_load/__pycache__/pdf_load.cpython-310.pyc


BIN
rag/document_load/__pycache__/pdf_load.cpython-311.pyc


BIN
rag/document_load/__pycache__/txt_load.cpython-310.pyc


BIN
rag/document_load/__pycache__/txt_load.cpython-311.pyc


BIN
rag/document_load/__pycache__/word_load.cpython-310.pyc


+ 26 - 0
rag/document_load/image_load.py

@@ -0,0 +1,26 @@
+import os
+
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.data.read_api import read_local_images
+
+
+class MinerUParseImage():
+    # def __init__(self, knowledge_id):
+    #     self.knowledge_id = knowledge_id
+
+    async def extract_text(self, file_path):
+        local_image_dir = "./tmp_file/images"
+        image_dir = str(os.path.basename(local_image_dir))
+
+        os.makedirs(local_image_dir, exist_ok=True)
+
+        image_writer = FileBasedDataWriter(local_image_dir)
+
+
+        ds = read_local_images(file_path)[0]  # 
+        infer_result = ds.apply(doc_analyze, ocr=True)
+        pipe_result = infer_result.pipe_ocr_mode(image_writer)
+        content_list_content = pipe_result.get_content_list(image_dir)
+
+        return content_list_content

+ 26 - 0
rag/document_load/office_load.py

@@ -0,0 +1,26 @@
+import os
+
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.data.read_api import read_local_office
+
+
+class MinerUParseOffice():
+    # def __init__(self, knowledge_id):
+    #     self.knowledge_id = knowledge_id
+
+    async def extract_text(self, file_path):
+        local_image_dir = "./tmp_file/images"
+        image_dir = str(os.path.basename(local_image_dir))
+
+        os.makedirs(local_image_dir, exist_ok=True)
+
+        image_writer = FileBasedDataWriter(local_image_dir)
+
+
+        ds = read_local_office(file_path)[0]  # 
+        infer_result = ds.apply(doc_analyze, ocr=True)
+        pipe_result = infer_result.pipe_ocr_mode(image_writer)
+        content_list_content = pipe_result.get_content_list(image_dir)
+
+        return content_list_content

+ 273 - 0
rag/document_load/pdf_load.py

@@ -0,0 +1,273 @@
+import fitz  # PyMuPDF
+import os
+from PIL import Image
+import io
+import pdfplumber
+from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
+from utils.upload_file_to_oss import UploadMinio
+from config import minio_config
+
+import os
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.config.enums import SupportedPdfParseMethod
+
+
+class PDFLoader(UnstructuredFileLoader):
+    def __init__(self, file_json):
+        self.base_path = "./tmp_file"
+        self.file_json = file_json
+        self.flag = self.file_json.get("flag")  # to be refined later
+        self.file_path_process()
+        if self.flag == "update":
+            self.flag_image_info_dict = {}
+            if not self.output_pdf_path:
+                self.upload_minio = UploadMinio()
+                self.image_positions_dict = self.get_image_positions()
+                self.images_path_dict, self.flag_image_info_dict = self.save_images()
+                self.replace_images_with_text()
+
+        else:
+            self.upload_minio = UploadMinio()
+            self.image_positions_dict = self.get_image_positions()
+            self.images_path_dict, self.flag_image_info_dict = self.save_images()
+            self.replace_images_with_text()
+
+    def file_path_process(self):
+        self.knowledge_id = self.file_json.get("knowledge_id")
+        self.document_id = self.file_json.get("document_id")
+        know_path = self.base_path + f"/{self.knowledge_id}"
+        
+        self.file_name = self.file_json.get("name")
+        self.output_pdf_name = "output_" + self.file_name
+        self.input_pdf_path = os.path.join(know_path, self.file_name)
+        self.output_pdf_path = os.path.join(know_path, self.output_pdf_name)
+        self.file_name_list = self.file_name.split(".")
+        self.image_dir = ".".join(self.file_name_list[:-1])
+        self.save_image_path = know_path + "/" + self.document_id
+    
+    def get_image_positions(self):
+        images_dict = {}
+        with pdfplumber.open(self.input_pdf_path) as pdf:
+            page_num = 0
+            for page in pdf.pages:
+                images_dict[page_num] = {}
+                image_num = 0
+                img_list = {}
+                img_list[image_num] = {}
+                for image in page.images:
+                    #print("Image position:", image)
+                    img_list[image_num] = {"x0":image['x0'],"y0":image['y0']}
+                    image_num += 1
+                    img_list[image_num] = {}
+                images_dict[page_num]=img_list
+                page_num += 1
+        # print(f"images list info: {images_dict}")
+        return images_dict
+    
+    def save_images(self):
+        # create the directory where extracted images are saved
+        os.makedirs(self.save_image_path, exist_ok=True)
+
+        # open the PDF with PyMuPDF
+        doc = fitz.open(self.input_pdf_path)
+        all_images_dict = {}
+        pdf_img_index = 1
+        flag_img_info = {}
+        for page_num in range(len(doc)):
+            page = doc.load_page(page_num)
+            images = page.get_images(full=True)
+            page_image_dict = {}
+            
+            for img_index, img in enumerate(images):
+                xref = img[0]  # XRef number of the image
+                base_image = doc.extract_image(xref)
+                image_bytes = base_image["image"]
+                
+                # convert the raw bytes into a PIL image
+                pil_image = Image.open(io.BytesIO(image_bytes))
+                
+                # generate a unique file name
+                # img_name = f"page{page_num+1}_img{img_index+1}.{base_image['ext']}"
+                img_name = f"{self.document_id}_{pdf_img_index}.{base_image['ext']}"
+                img_path = os.path.join(self.save_image_path, img_name)
+                
+                # page_image_dict[img_index] = img_path
+                # build the object key as knowledge_id/document_id/img_name
+                image_str = self.knowledge_id + "/" + self.document_id + "/" + img_name
+                replace_text = f"【示意图序号_{self.document_id}_{pdf_img_index}】"
+                page_image_dict[img_index] = replace_text  # text that replaces the image in the PDF
+
+                # save the image locally
+                pil_image.save(img_path)
+
+                # upload the saved image to OSS (MinIO)
+                self.upload_minio.upload_file(img_path, f"/pdf/{image_str}")
+                minio_url = minio_config.get("minio_url")
+                minio_bucket = minio_config.get("minio_bucket")
+                flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}//pdf/{image_str}"
+                pdf_img_index += 1
+                
+            all_images_dict[page_num] = page_image_dict
+        
+        # close the source document
+        doc.close()
+
+        return all_images_dict, flag_img_info
+    
+    def replace_images_with_text(self):
+        # open the original PDF
+        doc = fitz.open(self.input_pdf_path)
+        
+        # font settings for the inserted placeholder text
+        font_size = 12
+        font_name = "SimSun"
+        font_path = r"./utils/simsun.ttc"  # font file bundled with the project
+
+        # iterate over every page
+        for page_num in range(len(doc)):
+            page = doc.load_page(page_num)  # get the page
+
+            images = page.get_images(full=True)  # get all images on the page
+            page_height =  page.rect.height
+            # print("page_height: ", page_height)
+            
+            for img_index, img in enumerate(images):
+                xref = img[0]  # XRef number of the image
+                base_image = doc.extract_image(xref)  # extract the image
+    
+                bbox = fitz.Rect(img[1:5])
+                # print("bbox: ", bbox)
+
+                # delete the image object
+                # page.delete_xref(xref)  # delete the image
+                doc._deleteObject(img[0])
+                
+                # prepare the replacement text
+                # replacement_text = f"page{page_num+1}_img{img_index+1}.png"
+                replacement_text = self.images_path_dict[page_num][img_index]
+                print(f"替换的文本:{replacement_text}")
+                
+                # insert the replacement text at the deleted image's position
+                try:
+                    
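+                    # pdfplumber reports y0 as the distance from the bottom of the page (PDF coordinates),
+                    # while fitz.insert_text expects a top-left origin, so convert with page_height - y0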
+                    x0 = self.image_positions_dict[page_num][img_index]['x0']
+                    y0 = page_height - self.image_positions_dict[page_num][img_index]['y0']
+
+                    # coordinates for the inserted text
+                    print(f"x0: {x0}, y0: {y0}")
+
+                    # the built-in fitz font "china-s" (fontname="china-s") renders poorly, with the inserted text filling the whole line, so use the bundled SimSun font instead
+                    page.insert_text((x0,y0), replacement_text,fontname=font_name, fontfile=font_path, fontsize=font_size, color=(0, 0, 0))
+                    #page.insert_text((x,y+y1), replacement_text, fontsize=font_size, color=(0, 0, 0))
+                except Exception as e:
+                    print(f"Error inserting text for image on page {page_num + 1}: {e}")
+            
+
+        # save the modified PDF
+        doc.save(self.output_pdf_path)
+        doc.close()
+        print(f"Processed PDF saved to: {self.output_pdf_path}")
+
+    def file2text(self):
+        pdf_text = ""
+        with fitz.open(self.output_pdf_path) as doc:
+            for i, page in enumerate(doc):
+                text = page.get_text("text").strip()
+                lines = text.split("\n")
+                if len(lines) > 0 and lines[-1].strip().isdigit():
+                    lines = lines[:-1]  # drop the trailing page-number line
+
+                if len(lines) > 0 and lines[0].strip().isdigit():
+                    lines = lines[1:]  # drop the leading page-number line
+                text = "\n".join(lines)
+                # print(f"page text:{text.strip()}")
+                # pdf_text += text + "\n"
+                pdf_text += text
+        # print(pdf_text)
+        return pdf_text, self.flag_image_info_dict
+
+
+
+class MinerUParsePdf():
+    # def __init__(self, knowledge_id, minio_client):
+    #     self.knowledge_id = knowledge_id
+    #     self.minio_client = minio_client
+        
+    async def extract_text(self, file_path):
+        # pdf_file_name = file_path  
+        # prepare env
+        # local_image_dir = f"./tmp_file/{self.knowledge_id}/{doc_id}"
+        local_image_dir = f"./tmp_file/images"
+        image_dir = str(os.path.basename(local_image_dir))
+
+        os.makedirs(local_image_dir, exist_ok=True)
+
+        image_writer = FileBasedDataWriter(local_image_dir)
+
+        # read bytes
+        reader1 = FileBasedDataReader("")
+        pdf_bytes = reader1.read(file_path)  # read the pdf content
+
+        # proc
+        ## Create Dataset Instance
+        ds = PymuDocDataset(pdf_bytes)
+        infer_result = ds.apply(doc_analyze, ocr=True)
+
+        ## pipeline
+        pipe_result = infer_result.pipe_ocr_mode(image_writer)
+        content_list_content = pipe_result.get_content_list(image_dir)
+
+        # image_num = 1
+        # text = ""
+        # flag_img_info = {}
+        # current_page = ""
+        # for i,content_dict in enumerate(content_list_content):
+        #     page_index = content_dict.get("page_idx")
+        #     if i == 0:
+        #         current_page = page_index
+        #     elif page_index != current_page:
+        #         text += "<page>"
+        #         current_page = page_index
+        #     else:
+        #         pass
+
+        #     if content_dict.get("type") == "text":
+        #         content_text = content_dict.get("text")
+        #         text_level = content_dict.get("text_level")
+        #         if text_level:
+        #             text += "#" * text_level + content_text
+        #         else:
+        #             text += content_text
+
+        #     elif content_dict.get("type") in ("image", "table"):
+        #         image_path = content_dict.get("img_path")
+        #         image_name = image_path.split("/")[1]
+        #         save_image_path = local_image_dir + f"/{image_name}"
+        #         replace_text = f"【示意图序号_{doc_id}_{image_num}】"
+        #         minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
+        #         self.minio_client.upload_file(save_image_path, minio_file_path)
+        #         minio_url = minio_config.get("minio_url")
+        #         minio_bucket = minio_config.get("minio_bucket")
+        #         flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
+        #         text += replace_text
+        #         image_num += 1
+            
+        #     else:
+        #         ...
+            
+
+        return content_list_content
+                
+
+if __name__ == "__main__":
+    # input_pdf = r"G:/work/资料/5.1 BMP业务系统使用手册 - 切片.pdf"
+    # output_pdf = "./output.pdf"
+    # image_folder = "./extracted_images"
+    file_json = {
+        "knowledge_id": "1234",
+        "name": "5.1 BMP业务系统使用手册 - 切片.pdf",
+        "document_id": "2222"
+    }
+    loader = PDFLoader(file_json)
+    loader.replace_images_with_text()

+ 9 - 0
rag/document_load/txt_load.py

@@ -0,0 +1,9 @@
+class TextLoad:
+    def __init__(self):
+        pass
+        
+    async def file2text(self, file_path):
+        with open(file_path, "r", encoding="utf-8") as f:
+            content = f.read()
+
+        return content

+ 403 - 0
rag/documents_process.py

@@ -0,0 +1,403 @@
+import aiohttp
+import aiofiles
+from rag.db import MilvusOperate, MysqlOperate
+from rag.document_load.pdf_load import MinerUParsePdf
+from rag.document_load.office_load import MinerUParseOffice
+from rag.document_load.txt_load import TextLoad
+from rag.document_load.image_load import MinerUParseImage
+from utils.upload_file_to_oss import UploadMinio
+from utils.get_logger import setup_logger
+from config import minio_config
+import os
+import time
+from uuid import uuid1
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+
+pdf_parse = MinerUParsePdf()
+office_parse = MinerUParseOffice()
+text_parse = TextLoad()
+image_parse = MinerUParseImage()
+logger = setup_logger(__name__)
+
+
+class ProcessDocuments():
+    def __init__(self, file_json):
+        self.file_json = file_json
+        self.knowledge_id = self.file_json.get("knowledge_id")
+        self.mysql_client = MysqlOperate()
+        self.minio_client = UploadMinio()
+        self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)
+
+    def _get_file_type(self, name):
+        if name.endswith(".txt"):
+            return text_parse
+        elif name.endswith('.pdf'):
+            return pdf_parse
+        elif name.endswith((".doc", ".docx", "ppt", "pptx")):
+            return office_parse
+        elif name.endswith((".jpg", "png", "jpeg")):
+            return image_parse
+        else:
+            raise ValueError("不支持的文件格式")
+        
+    async def save_file_temp(self, session, url, name):
+        down_file_path = "./tmp_file" + f"/{self.knowledge_id}"
+        # down_file_path = "./tmp_file"
+        os.makedirs(down_file_path, exist_ok=True)
+
+        down_file_name = down_file_path + f"/{name}"
+        # if os.path.exists(down_file_name):
+        #     pass
+        # else:
+        async with session.get(url, ssl=False) as resp:
+            resp.raise_for_status()
+            content_length = resp.headers.get('Content-Length')
+            if content_length:
+                file_size = int(content_length)
+            else:
+                file_size = 0
+            async with aiofiles.open(down_file_name, 'wb') as f:
+                async for chunk in resp.content.iter_chunked(1024):
+                    await f.write(chunk)
+        
+        return down_file_name, file_size
+
+    def file_split_by_len(self, file_text):
+        split_map = {
+            "0": ["#"],  # 按标题段落切片
+            "1": ["<page>"],  # 按页切片
+            "2": ["\n"]   # 按问答对
+        }
+        separator_num = self.file_json.get("set_slice")
+        slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
+        separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
+        logger.info(f"文本切分字符:{separator}")
+        text_split = RecursiveCharacterTextSplitter(
+            separators=separator,
+            chunk_size=500,
+            chunk_overlap=40,
+            length_function=len
+        )
+        texts = text_split.split_text(file_text)
+        return texts
+    
+    def split_text(self, file_text):
+        text_split = RecursiveCharacterTextSplitter(
+            separators=["\n\n", "\n"],
+            chunk_size=500,
+            chunk_overlap=40,
+            length_function=len
+        )
+        texts = text_split.split_text(file_text)
+        return texts
+
+    
+    def split_by_title(self, file_content_list, set_table, doc_id):
+        # TODO: title-based splitting; images become placeholder markers, tables follow set_table (0 = image, 1 = HTML)
+        text_lists = []
+        text = ""
+        image_num = 1
+        flag_img_info = {}
+        level_1_text = ""
+        level_2_text = ""
+        for i, content_dict in enumerate(file_content_list):
+            text_type = content_dict.get("type")
+            content_text = content_dict.get("text")
+            if text_type == "text":
+                text_level = content_dict.get("text_level", "")
+                if text_level == 1:
+                    if not level_1_text:
+                        level_1_text = f"# {content_text}\n"
+                        text += f"# {content_text}\n"
+                    else:
+                        text_lists.append(text)
+                        text = f"# {content_text}\n"
+                        level_1_text = f"# {content_text}\n"
+                        level_2_text = ""
+
+                elif text_level == 2:
+                    if not level_2_text:
+                        text += f"## {content_text}\n"
+                        level_2_text = f"## {content_text}\n"
+                    else:
+                        text_lists.append(text)
+                        text = level_1_text + f"## {content_text}\n"
+                else:
+                    if text_level:
+                        text += text_level*"#" + " " + content_text + "\n"
+                    else:
+                        text += content_text
+
+            elif text_type == "table" and set_table == "1":
+                text += content_dict.get("table_body")
+
+            elif text_type in ("image", "table"):
+                image_path = content_dict.get("img_path")
+                if not image_path:
+                    continue
+                image_name = image_path.split("/")[1]
+                save_image_path = "./tmp_file/images/" + f"/{image_name}"
+                replace_text = f"【示意图序号_{doc_id}_{image_num}】"
+                minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
+                self.minio_client.upload_file(save_image_path, minio_file_path)
+                minio_url = minio_config.get("minio_url")
+                minio_bucket = minio_config.get("minio_bucket")
+                flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
+                text += replace_text
+                image_num += 1
+            if i+1 == len(file_content_list):
+                text_lists.append(text)
+        return text_lists, flag_img_info
+    
+    def split_by_page(self, file_content_list, set_table, doc_id):
+        # TODO: page-based splitting; images become placeholder markers, tables follow set_table (0 = image, 1 = HTML)
+        text_lists = []
+        current_page = ""
+        text = ""
+        image_num = 1
+        flag_img_info = {}
+        for i,content_dict in enumerate(file_content_list):
+                page_index = content_dict.get("page_idx")
+                if i == 0:
+                    current_page = page_index
+                elif page_index != current_page:
+                    text_lists.append(text)
+                    text = ""
+                    current_page = page_index
+
+                text_type = content_dict.get("type")
+                if text_type == "text":
+                    content_text = content_dict.get("text")
+                    text_level = content_dict.get("text_level")
+                    if text_level:
+                        text += "#" * text_level + " " + content_text
+                    else:
+                        text += content_text
+
+                elif text_type == "table" and set_table == "1":
+                    text += content_dict.get("table_body")
+
+                elif text_type in ("image", "table"):
+                    image_path = content_dict.get("img_path")
+                    image_name = image_path.split("/")[1]
+                    save_image_path = "./tmp_file/images/" + f"/{image_name}"
+                    replace_text = f"【示意图序号_{doc_id}_{image_num}】"
+                    minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
+                    self.minio_client.upload_file(save_image_path, minio_file_path)
+                    minio_url = minio_config.get("minio_url")
+                    minio_bucket = minio_config.get("minio_bucket")
+                    flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
+                    text += replace_text
+                    image_num += 1
+                if i+1 == len(file_content_list):
+                    text_lists.append(text)
+        return text_lists, flag_img_info
+
+    def split_by_self(self, file_content_list, set_table, slice_value, doc_id):
+        # TODO: split on a custom separator; images become placeholder markers, tables follow set_table (0 = image, 1 = HTML); chunks are capped at 500 characters and hard-cut beyond that
+        logger.info(f"自定义的分隔符:{slice_value}")
+        text = ""
+        image_num = 1
+        flag_img_info = {}
+        for i, content_dict in enumerate(file_content_list):
+            text_type = content_dict.get("type")
+            if text_type == "text":
+                content_text = content_dict.get("text")
+                text_level = content_dict.get("text_level")
+                if text_level:
+                    text += "#" * text_level + " " + content_text
+                else:
+                    text += content_text
+
+            elif text_type == "table" and set_table == "1":
+                text += content_dict.get("table_body")
+
+            elif text_type in ("image", "table"):
+                image_path = content_dict.get("img_path")
+                image_name = image_path.split("/")[1]
+                save_image_path = "./tmp_file/images/" + f"/{image_name}"
+                replace_text = f"【示意图序号_{doc_id}_{image_num}】"
+                minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
+                self.minio_client.upload_file(save_image_path, minio_file_path)
+                minio_url = minio_config.get("minio_url")
+                minio_bucket = minio_config.get("minio_bucket")
+                flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
+                text += replace_text
+                image_num += 1
+
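+        # split on the custom separator, then hard-cut any piece longer than 500 characters into 500-character slices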
+        split_lists = text.split(slice_value)
+        text_lists = []
+        for split_text in split_lists:
+            r = len(split_text)//500
+            if r >= 1:
+                for i in range(r+1):
+                    t = split_text[i*500:(i+1)*500]
+                    if t:
+                        text_lists.append(t)
+            else:
+                text_lists.append(split_text)
+                
+        return text_lists, flag_img_info
+
+    def file_split(self, file_content_list, doc_id):
+        # TODO: split according to the content list and return the chunk list plus the stored image links
+        separator_num = self.file_json.get("set_slice")
+        set_table = self.file_json.get("set_table")
+        # separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
+        # logger.info(f"文本切分字符:{separator}")
+        if isinstance(file_content_list, str):
+            file_text = file_content_list
+            text_lists = self.split_text(file_text)
+            return text_lists, {}
+
+        elif separator_num == "0":
+            # split by title/paragraph, using text_level 1/2, i.e. one "#" vs two "#"
+            text_lists, flag_img_info = self.split_by_title(file_content_list, set_table, doc_id)
+            return text_lists, flag_img_info
+        elif separator_num == "1":
+            # split page by page
+            text_lists, flag_img_info = self.split_by_page(file_content_list, set_table, doc_id)
+            return text_lists, flag_img_info
+        elif separator_num == "2":
+            # split by Q&A pairs (for Excel documents); not implemented yet
+            return [], {}
+        else:
+            # custom splitting: split on the user-defined separator and cap chunk length at 500 characters
+            slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
+            text_lists, flag_img_info = self.split_by_self(file_content_list, set_table, slice_value, doc_id)
+            return text_lists, flag_img_info
+
+    
+    def process_data_to_milvus_schema(self, text_lists, doc_id, name):
+        """组织数据格式:
+            {
+                "content": text,
+                "doc_id": doc_id,
+                "chunk_id": chunk_id,
+                "metadata": {"source": file_name},
+            }
+        """
+        docs = []
+        total_len = 0
+        for i, text in enumerate(text_lists):
+            chunk_id = str(uuid1())
+            chunk_len = len(text)
+            total_len += chunk_len
+            d = {
+                "content": text,
+                "doc_id": doc_id,
+                "chunk_id": chunk_id,
+                "metadata": {"source": name, "chunk_index": i+1, "chunk_len": chunk_len}
+            }
+            docs.append(d)
+        return docs, total_len
+    
+    async def process_documents(self, file_json):
+        # 文档下载
+        separator_num = file_json.get("set_slice")
+        if separator_num == "2":
+            return {"code": 500, "message": "暂不支持解析"}
+        docs = file_json.get("docs")
+        flag = file_json.get("flag")
+        success_doc = []  # 记录解析成功的文档id
+        for doc in docs:
+            url = doc.get("url")
+            name = doc.get("name")
+            doc_id = doc.get("document_id")
+            parse_file_status = False  # default until every insert for this doc succeeds
+            async with aiohttp.ClientSession() as session:
+                down_file_name, file_size = await self.save_file_temp(session, url, name)
+            
+            file_parse = self._get_file_type(name)
+
+            file_content_list = await file_parse.extract_text(down_file_name)
+            logger.info(f"mineru解析的pdf数据:{file_content_list}")
+
+            text_lists, flag_img_info = self.file_split(file_content_list, doc_id)
+            
+            docs, total_char_len = self.process_data_to_milvus_schema(text_lists, doc_id, name)
+            logger.info(f"存储到milvus的文本数据:{docs}")
+            if flag == "upload":
+                # 插入到milvus库中
+                insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
+                
+                if insert_milvus_flag:
+                    # 插入到mysql的slice info数据库中
+                    insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, doc_id)
+                else:
+                    # resp = {"code": 500, "message": insert_milvus_str}
+                    # return resp
+                    insert_slice_flag = False
+                    parse_file_status = False
+
+                if insert_slice_flag:
+                    # 插入mysql中的bm_media_replacement表中
+                    insert_img_flag, insert_mysql_info =  self.mysql_client.insert_to_image_url(flag_img_info, self.knowledge_id, doc_id)
+                else:
+                    # resp = {"code": 500, "message": insert_mysql_info}
+                    self.milvus_client._delete_by_doc_id(doc_id=doc_id)
+                    insert_img_flag = False
+
+                    # return resp
+                    parse_file_status = False
+
+                if insert_img_flag:
+                    # resp = {"code": 200, "message": "文档解析成功"}
+                    parse_file_status = True
+                
+                else:
+                    self.milvus_client._delete_by_doc_id(doc_id=doc_id)
+                    self.mysql_client.delete_image_url(doc_id=doc_id)
+                    # resp = {"code": 500, "message": insert_mysql_info}
+                    parse_file_status = False
+
+                # return resp
+            
+            elif flag == "update":  # 更新切片方式
+                # 先把库中的数据删除
+                self.milvus_client._delete_by_doc_id(doc_id=doc_id)
+                self.mysql_client.delete_to_slice(doc_id=doc_id)
+
+                insert_milvus_start_time = time.time()
+                insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
+                # insert_milvus_flag, insert_milvus_str = self.milvus_client._batch_insert_data(docs,text_lists)
+                insert_milvus_end_time = time.time()
+                logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
+
+                if insert_milvus_flag:
+                    # 插入到mysql的slice info数据库中
+                    insert_mysql_start_time = time.time()
+                    insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, doc_id)
+                    insert_mysql_end_time = time.time()
+                    logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
+                else:
+                    # resp = {"code": 500, "message": insert_milvus_str}
+                    # return resp
+                    insert_slice_flag = False
+                    parse_file_status = False
+                
+                if insert_slice_flag:
+                    # resp = {"code": 200, "message": "切片修改成功"}
+                    parse_file_status = True
+                
+                else:
+                    self.milvus_client._delete_by_doc_id(doc_id=doc_id)
+                    # resp = {"code":500, "message": insert_mysql_info}
+                    parse_file_status = False
+
+                # return resp
+
+            if parse_file_status:
+                success_doc.append(doc_id)
+            else:
+                if flag == "upload":
+                    for del_id in success_doc:
+                        self.milvus_client._delete_by_doc_id(doc_id=del_id)
+                        self.mysql_client.delete_image_url(doc_id=del_id)
+                        self.mysql_client.delete_to_slice(doc_id=del_id)
+
+                return {"code": 500, "message": "解析失败", "knowledge_id" : self.knowledge_id, "doc_info": {}}
+
+        return {"code": 200, "message": "解析成功", "knowledge_id" : self.knowledge_id, "doc_info": {"file_size": file_size, "total_char_len": total_char_len, "slice_num": len(text_lists)}}
+
+
+
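
For reference, a minimal sketch of driving the upload flow above from Python (assuming the payload carries the same keys that process_documents and file_split read; the URL, ids and knowledge_id are placeholder values):

import asyncio
from rag.documents_process import ProcessDocuments

file_json = {
    "knowledge_id": "kb_demo",        # placeholder collection / knowledge base id
    "flag": "upload",                 # "upload" inserts new slices, "update" re-slices an existing doc
    "set_slice": "0",                 # "0" = split by heading, "1" = by page, other = custom separator
    "set_table": "1",                 # "1" keeps tables as HTML text, otherwise tables are stored as images
    "slice_value": "\\n",             # only used by the custom-separator mode
    "docs": [{"url": "http://example.com/a.pdf", "name": "a.pdf", "document_id": "doc-001"}],
}

resp = asyncio.run(ProcessDocuments(file_json).process_documents(file_json))
print(resp)  # {"code": 200, ...} on success, {"code": 500, ...} if any document fails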

+ 184 - 0
rag/file_process.py

@@ -0,0 +1,184 @@
+import requests
+from fastapi import HTTPException
+from typing import List, Dict
+import os
+from uuid import uuid1
+import uuid
+from rag.document_load.pdf_load import PDFLoader
+from rag.document_load.txt_load import TextLoad
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from rag.db import MilvusOperate, MysqlOperate
+import httpx
+import time
+from utils.get_logger import setup_logger
+
+logger = setup_logger(__name__)
+
+
+file_dict = {
+    "pdf": PDFLoader,
+    # "txt": TextLoad
+}
+
+class ParseFile:
+    def __init__(self, file_json):
+        self.file_json = file_json
+        self.file_name = self.file_json.get("name")
+        # self.file_url = self.file_json
+        self.file_list = self.file_json.get("name").split(".")
+        file_type = self.file_list[-1]
+        self.flag = self.file_json.get("flag")
+        self.knowledge_id = self.file_json.get("knowledge_id")
+        self.doc_id = self.file_json.get("document_id")
+        self.save_file_to_tmp()
+        self.load_file = file_dict.get(file_type, PDFLoader)(self.file_json)
+        self.mysql_client = MysqlOperate()
+        self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)
+
+    def save_file_to_tmp(self):
+        # 远程文件存到本地处理
+        url = self.file_json.get("url")
+
+        know_path = "./tmp_file" + f"/{self.knowledge_id}"
+        os.makedirs(know_path, exist_ok=True)
+        tmp_file_name = f"./tmp_file/{self.knowledge_id}/{self.file_name}"
+
+        if self.flag == "upload":
+            file_response = requests.get(url=url)
+            with open(tmp_file_name, "wb") as f:
+                f.write(file_response.content)
+        elif self.flag == "update":
+            if os.path.exists(tmp_file_name):
+                pass
+            else:
+                file_response = requests.get(url=url)
+                with open(tmp_file_name, "wb") as f:
+                    f.write(file_response.content)
+
+        # return file_name
+
+
+    def file_split(self, file_text):
+        split_map = {
+            "0": ["\n"],
+            "1": ["\n"],
+            "2": ["\n"]
+        }
+        separator_num = self.file_json.get("set_slice")
+        slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
+        separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
+        logger.info(f"文本切分字符:{separator}")
+        text_split = RecursiveCharacterTextSplitter(
+            separators=separator,
+            chunk_size=300,
+            chunk_overlap=20,
+            length_function=len
+        )
+        texts = text_split.split_text(file_text)
+
+        return texts
+    
+    def process_data_to_milvus_schema(self, text_lists):
+        """组织数据格式:
+            {
+                "content": text,
+                "doc_id": doc_id,
+                "chunk_id": chunk_id,
+                "metadata": {"source": file_name},
+            }
+        """
+        # doc_id = self.file_json.get("document_id")
+        docs = []
+        for i, text in enumerate(text_lists):
+            chunk_id = str(uuid1())
+            d = {
+                "content": text,
+                "doc_id": self.doc_id,
+                "chunk_id": chunk_id,
+                "metadata": {"source": self.file_name, "chunk_index": i+1}
+            }
+            # d["content"] = text
+            # d["doc_id"] = doc_id
+            # d["chunk_id"] = chunk_id
+            # d["metadata"] = {"source": self.file_name, "chunk_index": i+1}
+            docs.append(d)
+        return docs
+
+    def save_file_to_db(self):
+        # 如果更改切片方式,需要删除对应knowledge id中doc id对应数据
+        flag = self.file_json.get("flag")
+        if flag == "update":
+            # 执行删除操作
+            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
+            self.mysql_client.delete_to_slice(doc_id=self.doc_id)
+            # self.mysql_client.delete_image_url(doc_id=doc_id)
+        file_text_start_time = time.time()
+        file_text, image_dict = self.load_file.file2text()
+        file_text_end_time = time.time()
+        logger.info(f"pdf加载成文本耗时:{file_text_end_time - file_text_start_time}")
+        text_lists = self.file_split(file_text)
+        file_split_end_time = time.time()
+        logger.info(f"文档切分的耗时:{file_split_end_time - file_text_end_time}")
+        docs = self.process_data_to_milvus_schema(text_lists)
+        logger.info(f"插入milvus的数据:{docs}")
+
+        # doc_id = self.file_json.get("document_id")
+
+        if flag == "upload":
+            # 插入到milvus库中
+            insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
+            
+            if insert_milvus_flag:
+                # 插入到mysql的slice info数据库中
+                insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, self.doc_id)
+            else:
+                resp = {"code": 500, "message": insert_milvus_str}
+                return resp
+
+            if insert_slice_flag:
+                # 插入mysql中的bm_media_replacement表中
+                insert_img_flag, insert_mysql_info =  self.mysql_client.insert_to_image_url(image_dict, self.knowledge_id, self.doc_id)
+            else:
+                resp = {"code": 500, "message": insert_mysql_info}
+                self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
+
+                return resp
+
+            if insert_img_flag:
+                resp = {"code": 200, "message": "文档解析成功"}
+            
+            else:
+                self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
+                self.mysql_client.delete_image_url(doc_id=self.doc_id)
+                resp = {"code": 500, "message": insert_mysql_info}
+
+            return resp
+
+
+        elif flag == "update":
+            # 插入到milvus库中
+            insert_milvus_start_time = time.time()
+            insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
+            # insert_milvus_flag, insert_milvus_str = self.milvus_client._batch_insert_data(docs,text_lists)
+            insert_milvus_end_time = time.time()
+            logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
+
+            if insert_milvus_flag:
+                # 插入到mysql的slice info数据库中
+                insert_mysql_start_time = time.time()
+                insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, self.doc_id)
+                insert_mysql_end_time = time.time()
+                logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
+            else:
+                resp = {"code": 500, "message": insert_milvus_str}
+                return resp
+            
+            if insert_slice_flag:
+                resp = {"code": 200, "message": "切片修改成功"}
+            
+            else:
+                self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
+                resp = {"code":500, "message": insert_mysql_info}
+
+            return resp
+
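
As a standalone illustration of the splitter parameters used in file_split above (the separator and sample text are made up; chunk_size and chunk_overlap match the values in the method):

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    separators=["\n"],     # same default separator as split_map
    chunk_size=300,        # maximum characters per chunk
    chunk_overlap=20,      # characters shared between neighbouring chunks
    length_function=len,
)
chunks = splitter.split_text("第一段内容\n第二段内容\n第三段内容")
print(len(chunks), [len(c) for c in chunks])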

+ 187 - 0
rag/llm.py

@@ -0,0 +1,187 @@
+from openai import OpenAI
+import requests
+import json
+from utils.get_logger import setup_logger
+from config import model_name_vllm_url_dict
+
+logger = setup_logger(__name__)
+
+class VllmApi():
+    def __init__(self, chat_json):
+        openai_api_key = "EMPTY"
+        model = chat_json.get("model")
+        vllm_url = model_name_vllm_url_dict.get(model)
+        openai_api_base = vllm_url
+        self.vllm_chat_url = f"{vllm_url}/chat/completions"
+        self.vllm_generate_url = f"{vllm_url}/completions"
+        self.client = OpenAI(
+            # defaults to os.environ.get("OPENAI_API_KEY")
+            api_key=openai_api_key,
+            base_url=openai_api_base,
+        )
+
+    def chat(self,
+             prompt : str = "",
+             model: str = "deepseek-r1:7b",
+             stream: bool = False,
+             top_p: float = 0.9,
+             temperature: float = 0.6,
+             max_tokens: int = 1024,
+             history: list = []
+             ):
+        if history:
+            messages = history
+        else:
+            messages = [{"role": "user", "content": prompt}]
+        chat_response = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            stream=stream,
+            top_p=top_p,
+            temperature=temperature, 
+            max_tokens=max_tokens
+        )
+
+        # 针对deepseek的模型,是否输出think部分
+        yield_reasoning_content = True
+        yield_content = True
+        has_reason = ""
+        if stream:
+            for chunk in chat_response:
+                logger.info(f"vllm返回的chunk信息:{chunk}")
+                reasoning_content = None
+                content = None
+                chat_id = chunk.id
+                # Check the content is reasoning_content or content
+                if chunk.choices[0].delta.role == "assistant":
+                    continue
+                elif hasattr(chunk.choices[0].delta, "reasoning_content"):
+                    reasoning_content = chunk.choices[0].delta.reasoning_content
+                    if reasoning_content:
+                        has_reason += reasoning_content
+                elif hasattr(chunk.choices[0].delta, "content"):
+                    content = chunk.choices[0].delta.content
+
+                if reasoning_content is not None:
+                    if yield_reasoning_content:
+                        yield_reasoning_content = False
+                        reasoning_content = "```think" + reasoning_content
+                        # print("reasoning_content:", end="", flush=True)
+                    # print(reasoning_content, end="", flush=True)
+                    # yield reasoning_content
+                    yield {"id": chat_id, "event": "add", "data": reasoning_content}
+                    
+                elif content is not None:
+                    if yield_content:
+                        yield_content = False
+                        if has_reason:
+                            content = "think```" + content
+                        else:
+                            content = content
+                    #     print("\ncontent:", end="", flush=True) 
+                    # print(content, end="", flush=True)
+                    # yield content
+                    yield {"id": chat_id, "event": "add", "data": content}
+                
+                if chunk.choices[0].finish_reason:
+                    yield {"id": chat_id, "event": "finish", "data": ""}
+        
+        else:
+            # print(f"chat response: {chat_response.model_dump_json()}")
+            yield chat_response.choices[0].message.content
+
+    def generate(self,
+                 prompt: str,
+                 model: str = "deepseek-r1:7b",
+                 history: list = [],
+                 stream: bool = False
+                 ):
+        completion = self.client.completions.create(
+            model=model,
+            prompt=prompt,
+            max_tokens=1024,
+            stream=stream
+        )
+
+        if stream:
+            for chunk in completion:
+                print(f"generate chunk: {chunk}")
+                yield chunk
+        
+        else:
+            yield completion  # generator function: "return" would discard the result
+        
+    def request_generate(self, model, prompt, max_tokens: int = 1024, temperature: float = 0.6, stream: bool = False):
+        json_data = {
+            "model": model,
+            "prompt": prompt,
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "stream": stream
+        }
+        response = requests.post(self.vllm_generate_url,json=json_data, stream=stream)
+        response.raise_for_status()
+        if stream:
+            for line in response.iter_lines():
+                if line:
+                    line_str = line.decode("utf-8")
+                    if not line_str.startswith("data: "):
+                        continue
+                    json_str = line_str[len("data: "):]
+                    if json_str == "[DONE]":
+                        break
+                    print(f"返回的数据:{json.loads(json_str)}")
+                    yield json.loads(json_str)
+                
+        else:
+            logger.info(f"直接返回结果:{response.json()}")
+            yield response.json()
+
+    def request_chat(self, 
+                     model, 
+                     prompt, 
+                     history: list = [], 
+                     temperature: float = 0.6, 
+                     stream: bool = False,
+                     top_p: float = 0.7):
+        history.append({"role": "user", "content": prompt})
+        json_data = {
+            "model": model,
+            "messages": history,
+            "temperature": temperature,
+            "stream": stream,
+            "top_p": top_p
+        }
+        response = requests.post(self.vllm_chat_url,json=json_data, stream=stream)
+        response.raise_for_status()
+        if stream:
+            for line in response.iter_lines():
+                if line:
+                    line_str = line.decode("utf-8")
+                    if not line_str.startswith("data: "):
+                        continue
+                    json_str = line_str[len("data: "):]
+                    if json_str == "[DONE]":
+                        break
+                    print(f"chat模式返回的数据:{json.loads(json_str)}")
+                    yield json.loads(json_str)
+        else:
+            print(f"聊天模式直接返回结果:{response.json()}")
+            yield response.json()  # keep generator semantics consistent with the streaming branch
+
+
+def main():
+    history = [{"role": "system", "content": "你是一个非常有帮助的助手,在回答用户问题的时候请以<think>开头。"}]
+    # prompt = "请帮我计算鸡兔同笼的问题。从上面数有35个头,从下面数有94只脚,请问分别多少只兔子多少只鸡?"
+    prompt = "请帮我将下面提供的中文翻译成日文,要求:1、直接输出翻译的结果,2、不要进行任何解释。需要翻译的内容:我下飞机的时候行李丢了。"
+    model = "DeepSeek-R1-Distill-Qwen-14B"
+    vllm_chat_resp = VllmApi({"model": model}).request_chat(prompt=prompt, model=model, history=history, stream=True)
+
+    # print("vllm 回复:")
+    for chunk in vllm_chat_resp:
+        pass
+    #     print(chunk, end='', flush=True)
+
+if __name__=="__main__":
+    main()
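
A minimal sketch of consuming the streaming chat generator above; the model name is one of the keys in model_name_vllm_url_dict and the prompt is illustrative:

from rag.llm import VllmApi

chat_json = {"model": "DeepSeek-R1-Distill-Qwen-14B"}
llm = VllmApi(chat_json)
for event in llm.chat(prompt="你好,请做个自我介绍。", model=chat_json["model"], stream=True):
    # each event looks like {"id": ..., "event": "add" | "finish", "data": "..."}
    if event["event"] == "add":
        print(event["data"], end="", flush=True)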

+ 16 - 0
rag/load_model.py

@@ -0,0 +1,16 @@
+from pymilvus import model
+import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# 使用sentence transformer方式加载模型
+# embedding_path = r"/opt/models/multilingual-e5-large-instruct/"  # 线上路径
+embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"  # 本地路径
+sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path,device=device)
+
+# rerank模型
+# bce_rerank_model_path = r"/opt/models/bce-reranker-base_v1"  # 线上路径
+bce_rerank_model_path = r"G:/work/code/models/bce-reranker-base_v1"  # 本地路径
+bce_rerank_tokenizer = AutoTokenizer.from_pretrained(bce_rerank_model_path)
+bce_rerank_base_model = AutoModelForSequenceClassification.from_pretrained(bce_rerank_model_path).to(device)
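
The reranker loaded above is not exercised in this file; a minimal pairwise-scoring sketch with it, following the usual cross-encoder pattern (query and passages are illustrative):

import torch
from rag.load_model import bce_rerank_tokenizer, bce_rerank_base_model, device

query = "加班申请流程"
passages = ["加班需要提前一天在系统中提交申请。", "公司年假规则说明。"]
pairs = [[query, p] for p in passages]
with torch.no_grad():
    inputs = bce_rerank_tokenizer(pairs, padding=True, truncation=True,
                                  max_length=512, return_tensors="pt").to(device)
    scores = bce_rerank_base_model(**inputs).logits.view(-1).float()
print(scores.tolist())  # higher score = more relevant passage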

+ 0 - 0
rag/vector_db/__init__.py


BIN
rag/vector_db/__pycache__/__init__.cpython-310.pyc


BIN
rag/vector_db/__pycache__/__init__.cpython-311.pyc


BIN
rag/vector_db/__pycache__/milvus_vector.cpython-310.pyc


BIN
rag/vector_db/__pycache__/milvus_vector.cpython-311.pyc


+ 507 - 0
rag/vector_db/milvus_vector.py

@@ -0,0 +1,507 @@
+import time
+import numpy as np
+from pymilvus import (
+    MilvusClient,
+    DataType,
+    Function,
+    FunctionType,
+    AnnSearchRequest,
+    RRFRanker,
+)
+# from pymilvus.model.hybrid import BGEM3EmbeddingFunction
+from pymilvus import model
+from rag.load_model import sentence_transformer_ef
+from utils.get_logger import setup_logger
+import torch
+device = "cuda" if torch.cuda.is_available() else "cpu"
+logger = setup_logger(__name__)
+
+# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"
+# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path,device=device)
+
+
+embedding_mapping = {
+    "e5": sentence_transformer_ef,
+    "multilingual-e5-large-instruct": sentence_transformer_ef,
+}
+
+class HybridRetriever:
+    def __init__(self, uri, embedding_name:str="e5", collection_name:str ="hybrid"):
+        self.uri = uri
+        self.collection_name = collection_name
+        # self.embedding_function = sentence_transformer_ef
+        self.embedding_function = embedding_mapping.get(embedding_name, sentence_transformer_ef)
+        self.use_reranker = True
+        self.use_sparse = True
+        self.client = MilvusClient(uri=uri)
+    
+    def has_collection(self):
+        try:
+            collection_flag = self.client.has_collection(self.collection_name)
+            logger.info(f"查询向量库的结果:{collection_flag}")
+        except Exception as e:
+            logger.info(f"查询向量库是否存在出错:{e}")
+            collection_flag = False
+        return collection_flag
+
+    def build_collection(self):
+        if isinstance(self.embedding_function.dim, dict):
+            dense_dim = self.embedding_function.dim["dense"]
+        else:
+            dense_dim = self.embedding_function.dim
+        logger.info(f"创建数据库的向量维度:{dense_dim}")
+        analyzer_params={
+            "type": "chinese"
+        }
+
+        schema = MilvusClient.create_schema()
+        schema.add_field(
+            field_name="pk",
+            datatype=DataType.VARCHAR,
+            is_primary=True,
+            auto_id=True,
+            max_length=100,
+        )
+        schema.add_field(
+            field_name="content",
+            datatype=DataType.VARCHAR,
+            max_length=65535,
+            analyzer_params=analyzer_params,
+            enable_match=True,
+            enable_analyzer=True,
+        )
+        schema.add_field(
+            field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR
+        )
+        schema.add_field(
+            field_name="dense_vector", datatype=DataType.FLOAT_VECTOR, dim=dense_dim
+        )
+        schema.add_field(field_name="doc_id", datatype=DataType.VARCHAR, max_length=64)
+        schema.add_field(
+            field_name="chunk_id", datatype=DataType.VARCHAR, max_length=64
+        )
+        schema.add_field(field_name="metadata", datatype=DataType.JSON)
+
+        functions = Function(
+            name="bm25",
+            function_type=FunctionType.BM25,
+            input_field_names=["content"],
+            output_field_names="sparse_vector",
+        )
+
+        schema.add_function(functions)
+
+        index_params = MilvusClient.prepare_index_params()
+        index_params.add_index(
+            field_name="sparse_vector",
+            index_type="SPARSE_INVERTED_INDEX",
+            metric_type="BM25",
+        )
+        index_params.add_index(
+            field_name="dense_vector", index_type="FLAT", metric_type="IP"
+        )
+        try:
+            self.client.create_collection(
+                collection_name=self.collection_name,
+                schema=schema,
+                index_params=index_params,
+            )
+            return "create_collection_success"
+        except Exception as e:
+            logger.error(f"创建{self.collection_name}数据库失败:{e}")
+            return "create_collection_error"
+
+    def insert_data(self, chunk, metadata):
+        logger.info("准备插入数据")
+        with torch.no_grad():
+            embedding = self.embedding_function([chunk])
+        logger.info("获取文本的向量信息。")
+        if isinstance(embedding, dict) and "dense" in embedding:
+            # bge embedding 获取embedding的方式
+            dense_vec = embedding["dense"][0]
+        else:
+            dense_vec = embedding[0]
+        
+        try:
+            self.client.insert(
+                self.collection_name, {"dense_vector": dense_vec, **metadata}
+            )
+            logger.info("插入一条数据成功。")
+            return True, "success"
+        except Exception as e:
+            doc_id = metadata.get("doc_id")
+            logger.error(f"处理文档:{doc_id},插入数据出错:{e}")
+            self.delete_by_doc_id(doc_id=doc_id)
+            return False, str(e)
+        
+    def batch_insert_data(self, chunks, metadatas):
+        logger.info("准备插入数据")
+        embedding_lists = self.embedding_function.encode_documents(chunks)
+        logger.info("获取文本的向量信息。")
+        record_lists = []
+        for embedding, metadata in zip(embedding_lists, metadatas):
+            if isinstance(embedding, dict) and "dense" in embedding:
+                # bge embedding 获取embedding的方式
+                dense_vec = embedding["dense"][0]
+            else:
+                dense_vec = embedding.tolist()
+
+            # if hasattr(dense_vec, 'tolist'):
+            #     dense_vec = dense_vec.tolist()
+
+            # logger.info(f"向量维度:{dense_vec}")
+            
+            # if isinstance(dense_vec, (float, int)):
+            #     dense_vec = [dense_vec]
+            
+            # if isinstance(dense_vec, np.float32):
+            #     dense_vec = [float(dense_vec)]
+            
+            record = {"dense_vector": dense_vec}
+            record.update(metadata)
+            record_lists.append(record)
+        
+        try:
+            self.client.insert(
+                self.collection_name, record_lists
+            )
+            logger.info("插入数据成功。")
+            return True, "success"
+        except Exception as e:
+            doc_id = metadata.get("doc_id")
+            logger.error(f"处理文档:{doc_id},插入数据出错:{e}")
+            self.delete_by_doc_id(doc_id=doc_id)
+            return False, str(e)
+                
+
+    def search(self, query: str, k: int = 20, mode="hybrid"):
+
+        output_fields = [
+            "content",
+            "doc_id",
+            "chunk_id",
+            "metadata",
+        ]
+        if mode in ["dense", "hybrid"]:
+            with torch.no_grad():
+                embedding = self.embedding_function([query])
+            if isinstance(embedding, dict) and "dense" in embedding:
+                dense_vec = embedding["dense"][0]
+            else:
+                dense_vec = embedding[0]
+
+        if mode == "sparse":
+            results = self.client.search(
+                collection_name=self.collection_name,
+                data=[query],
+                anns_field="sparse_vector",
+                limit=k,
+                output_fields=output_fields,
+            )
+        elif mode == "dense":
+            results = self.client.search(
+                collection_name=self.collection_name,
+                data=[dense_vec],
+                anns_field="dense_vector",
+                limit=k,
+                output_fields=output_fields,
+            )
+        elif mode == "hybrid":
+            full_text_search_params = {"metric_type": "BM25"}
+            full_text_search_req = AnnSearchRequest(
+                [query], "sparse_vector", full_text_search_params, limit=k
+            )
+
+            dense_search_params = {"metric_type": "IP"}
+            dense_req = AnnSearchRequest(
+                [dense_vec], "dense_vector", dense_search_params, limit=k
+            )
+
+            results = self.client.hybrid_search(
+                self.collection_name,
+                [full_text_search_req, dense_req],
+                ranker=RRFRanker(),
+                limit=k,
+                output_fields=output_fields,
+            )
+        else:
+            raise ValueError("Invalid mode")
+        return [
+            {
+                "doc_id": doc["entity"]["doc_id"],
+                "chunk_id": doc["entity"]["chunk_id"],
+                "content": doc["entity"]["content"],
+                "metadata": doc["entity"]["metadata"],
+                "score": doc["distance"],
+            }
+            for doc in results[0]
+        ]
+    
+    def query_filter(self, doc_id, filter_field):
+        # doc id 文档id,content中包含 filter_field 字段的
+        query_output_field = [
+            "content",
+            "chunk_id",
+            "doc_id",
+            "metadata"
+        ]
+        # query_expr = f"doc_id in {doc_id} && content like '%{filter_field}%'"
+        # 根据doc_id查询如果有关键词,根据关键词查询,如果没有关键词,只根据doc_id查询
+        if filter_field:
+            query_expr = f"doc_id == '{doc_id}' && content like '%{filter_field}%'"
+        else:
+            query_expr = f"doc_id == '{doc_id}'"
+
+        try:
+            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
+        except Exception as e:
+            logger.error(f"根据关键词查询数据失败:{e}")
+            query_filter_results = [{"code": 500}]
+        return query_filter_results
+        # for result in query_filter_results:
+        #     print(f"根据doc id 和 field 过滤结果: {result}\n\n")
+    
+    def query_chunk_id(self, chunk_id):
+        # chunk id,查询切片
+        query_output_field = [
+            "content",
+            "doc_id",
+            "chunk_id",
+            # "metadata"
+        ]
+        query_expr = f"chunk_id == '{chunk_id}'"
+        try:
+            query_filter_results = self.client.query(collection_name=self.collection_name, filter=query_expr, output_fields=query_output_field)
+        except Exception as e:
+            logger.info(f"根据chunk id 查询出错:{e}")
+            query_filter_results = [{"code": 500}]
+        return query_filter_results
+
+    def update_data(self, chunk_id, chunk):
+        # 根据chunk id查询对应的信息,
+        chunk_expr = f"chunk_id == '{chunk_id}'"
+        chunk_output_fields = [
+            "pk",
+            "doc_id",
+            "chunk_id",
+            "metadata"
+        ]
+        try:
+            chunk_results = self.client.query(collection_name=self.collection_name, filter=chunk_expr, output_fields=chunk_output_fields)
+            # logger.info(f"{chunk_id}更新切片的信息:{chunk_results}")
+        except Exception as e:
+            logger.error(f"更新切片数据时查询失败:{e}")
+            return "update_query_error", ""
+        if not chunk_results:
+            logger.info(f"根据{chunk_id}未在向量库中查询到对应数据,无法更新数据")
+            return "update_query_no_result", ""
+        with torch.no_grad():
+            embedding = self.embedding_function([chunk])
+        if isinstance(embedding, dict) and "dense" in embedding:
+            # bge embedding 获取embedding的方式
+            dense_vec = embedding["dense"][0]
+        else:
+            dense_vec = embedding[0]
+        chunk_dict = chunk_results[0]
+        metadata = chunk_dict.get("metadata")
+        old_chunk_len = metadata.get("chunk_len")
+        chunk_len = len(chunk)
+        metadata["chunk_len"] = chunk_len
+        chunk_dict["content"] = chunk
+        chunk_dict["dense_vector"] = dense_vec
+        chunk_dict["metadata"] = metadata
+        try:
+            update_res = self.client.upsert(collection_name=self.collection_name, data=[chunk_dict])
+            logger.info(f"更新返回的数据:{update_res}")
+            return "update_success", chunk_len - old_chunk_len
+        except Exception as e:
+            logger.error(f"更新数据时出错:{e}")
+            return "update_error", ""
+            
+    def delete_collection(self, collection):
+        try:
+            self.client.drop_collection(collection_name=collection)
+            return "delete_collection_success"
+        except Exception as e:
+            logger.error(f"删除{collection}失败,出错原因:{e}")
+            return "delete_collection_error"
+
+    def delete_by_chunk_id(self, chunk_id:str = None):
+        # 根据文档id查询主键值 milvus只支持主键删除
+        expr = f"chunk_id == '{chunk_id}'"
+        try:
+            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk","metadata"])  # 获取主键 id
+            logger.info(f"根据切片id:{chunk_id},查询的数据:{results}")
+        except Exception as e:
+            logger.error(f"根据切片id查询主键失败:{e}")
+            return "delete_query_error", []
+        if not results:
+            print(f"No data found for chunk id: {chunk_id}")
+            # return "delete_no_result", []
+            return "delete_success", 0
+        
+        # 提取主键值
+        primary_keys = [result["pk"] for result in results]
+        chunk_len = [result["metadata"]["chunk_len"] for result in results]
+        logger.info(f"获取到的主键信息:{primary_keys}")
+        
+        # 执行删除操作
+        expr_delete = f"pk in {primary_keys}"  # 构造删除表达式
+        try:
+            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
+            logger.info(f"Deleted data with chunk_id: {delete_res}")
+            return "delete_success", chunk_len
+        except Exception as e:
+            logger.error(f"删除数据失败:{e}")
+            return "delete_error", []
+        
+    def delete_by_doc_id(self, doc_id:str =None):
+        # 根据文档id查询主键值 milvus只支持主键删除
+        expr = f"doc_id == '{doc_id}'"
+        try:
+            results = self.client.query(collection_name=self.collection_name, filter=expr, output_fields=["pk"])  # 获取主键 id
+        except Exception as e:
+            logger.error(f"根据切片id查询主键失败:{e}")
+            return "delete_query_error"
+        if not results:
+            print(f"No data found for doc_id: {doc_id}")
+            return "delete_no_result"
+        
+        # 提取主键值
+        primary_keys = [result["pk"] for result in results]
+        logger.info(f"获取到的主键信息:{primary_keys}")
+        
+        # 执行删除操作
+        expr_delete = f"pk in {primary_keys}"  # 构造删除表达式
+        try:
+            delete_res = self.client.delete(collection_name=self.collection_name, filter=expr_delete)
+            logger.info(f"Deleted data with doc_id: {delete_res}")
+            return "delete_success"
+        except Exception as e:
+            logger.error(f"删除数据失败:{e}")
+            return "delete_error"
+        
+    
+# 测试
+# def parse_json_to_schema_data_format():
+#     sql_text_list = load_sql_query_ddl_info()
+#     docs = []
+#     for sql_info in sql_text_list:
+#         sql_text = sql_info.get("sql_text")
+#         source = ",".join(sql_info.get("table_list", []))
+#         source = sql_info.get("source") if not source else source
+#         sql = sql_info.get("sql")
+#         ddl = sql_info.get("table_ddl")
+#         metadata = {"source": source, "sql": sql, "ddl": ddl}
+#         text_list = sql_info.get("sim_sql_text_list", [])
+#         text_list.append(sql_text)
+#         doc_id = str(uuid4())
+#         for text in text_list:
+#             chunk_id = str(uuid4())
+#             insert_dict = {
+#                 "content": text,
+#                 "doc_id": doc_id,
+#                 "chunk_id": chunk_id,
+#                 "metadata": metadata
+#             }
+#             docs.append(insert_dict)
+    
+#     return docs
+
+# def insert_data_to_milvus(standard_retriever):
+#     sql_dataset = parse_json_to_schema_data_format()
+#     standard_retriever.build_collection()
+#     for sql_dict in sql_dataset:
+#         text = sql_dict["content"]
+#         standard_retriever.insert_data(text, sql_dict)
+
+def main():
+    # dense_ef = BGEM3EmbeddingFunction()
+    standard_retriever = HybridRetriever(
+        uri="http://localhost:19530",
+        collection_name="milvus_hybrid",
+        embedding_name="e5",
+    )
+    # 插入数据
+    # insert_data_to_milvus(standard_retriever)
+
+    # 查询 混合检索:hybrid ,稀疏检索sparse,向量检索:dense
+    # results = standard_retriever.search("查一下加班情况", mode="hybrid", k=3)
+    # model=sparse 稀疏检索
+    # print(f"稀疏检索结果:{results}")
+
+    # model=dense 向量检索
+    # print(f"向量检索结果:{results}")
+
+    # model=hybrid
+    # print(f"向量检索结果:{results}")
+
+    # 根据doc id删除数据
+    # delete_start_time = time.time()
+    # doc_id = "e72ebf78-6c0d-410b-8fbb-9a2057673064"
+    # standard_retriever.delete_data(doc_id=doc_id)
+    # delete_end_time = time.time()
+    # print(f"删除耗时:{delete_end_time-delete_start_time}")
+
+    # 根据chunk_id 更新数据
+    # update_start_time = time.time()
+    # chunk = "查询一下员工加班的情况"
+    # chunk_id = "a5e8dded-f5a7-4a1f-92cd-82fa8113b418"
+    # standard_retriever.update_data(chunk_id, chunk)
+    # update_end_time = time.time()
+    # print(f"更新数据的时间:{update_end_time-update_start_time}")
+
+    # 根据doc id 和关键字查询
+    query_start_time = time.time()
+    filter_field = "加班"
+    doc_id = ["7b73ae0b-db97-4315-ba71-783fe7a69c61", "96bbe5a8-5fcf-4769-8343-938acb8735bd"]
+    standard_retriever.query_filter(doc_id, filter_field)
+    query_end_time = time.time()
+    print(f"关键字搜索数据的时间:{query_end_time-query_start_time}")
+
+
+if __name__=="__main__":
+    main()
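
A minimal end-to-end sketch for HybridRetriever above: create the collection, insert one chunk in the schema used by process_data_to_milvus_schema, then run a hybrid search (collection name, ids and texts are placeholders):

from uuid import uuid1
from rag.vector_db.milvus_vector import HybridRetriever

retriever = HybridRetriever(uri="http://127.0.0.1:19530", collection_name="kb_demo")
if not retriever.has_collection():
    retriever.build_collection()

chunk = "员工加班需要提前一天提交申请。"
doc = {
    "content": chunk,
    "doc_id": "doc-001",
    "chunk_id": str(uuid1()),
    "metadata": {"source": "a.pdf", "chunk_index": 1, "chunk_len": len(chunk)},
}
retriever.insert_data(chunk, doc)            # fields are merged into the record via **metadata
for hit in retriever.search("加班怎么申请", k=3, mode="hybrid"):
    print(hit["score"], hit["content"])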

+ 146 - 0
rag_server.py

@@ -0,0 +1,146 @@
+# 请求的入口
+
+from fastapi import FastAPI, File, UploadFile, Form, Request, Response, WebSocket, WebSocketDisconnect, Depends, APIRouter, Body
+from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from sse_starlette import EventSourceResponse
+import uvicorn
+from utils.get_logger import setup_logger
+from rag.vector_db.milvus_vector import HybridRetriever
+from response_info import generate_message, generate_response
+from rag.db import MilvusOperate
+from rag.file_process import ParseFile
+from rag.documents_process import ProcessDocuments
+from rag.chat_message import ChatRetrieverRag
+
+
+logger = setup_logger(__name__)
+app = FastAPI()
+
+# 设置跨域
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
+@app.post("/upload_knowledge")
+async def upload_file_to_db(file_json: dict):
+    logger.info(f"上传文件请求参数:{file_json}")
+    parse_file = ProcessDocuments(file_json)
+    resp = await parse_file.process_documents(file_json)
+    # parse_file = ParseFile(file_json)
+    # resp = parse_file.save_file_to_db()
+
+    logger.info(f"上传文件响应结果:{resp}")
+    return JSONResponse(resp)
+
+
+# @app.post("/network/search")
+# async def chat_with_rag(request: Request, chat_json:dict):
+#     retriever = ChatRetrieverRag(chat_json)
+
+#     return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
+
+@app.post("/rag/chat")
+async def chat_with_rag(request: Request, chat_json:dict):
+    retriever = ChatRetrieverRag(chat_json)
+
+    return EventSourceResponse(retriever.generate_event(chat_json, request), ping=300)
+
+@app.post("/rag/query")
+async def generate_query(request: Request, query_json:dict):
+    logger.info(f"请求参数:{query_json}")
+    relevant_query = ChatRetrieverRag(query_json)
+    relevant_json = await relevant_query.generate_relevant_query(query_json)
+
+    return JSONResponse(relevant_json)
+
+@app.get("/rag/slice/search/{chat_id}")
+async def search_slice_by_chat_id(request: Request, chat_id: str = None):
+    chat = ChatRetrieverRag(chat_id=chat_id)
+    chunk_json = await chat.search_slice()
+
+    return JSONResponse(chunk_json)
+
+
+@app.delete("/rag/delete_slice/{slice_id}/{knowledge_id}/{document_id}")
+async def delete_by_chunk_id(slice_id:str=None, knowledge_id:str=None, document_id:str=None):
+    logger.info(f"删除切片接口中,知识库:{knowledge_id}, 切片id:{slice_id}")
+    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_chunk_id(slice_id, knowledge_id, document_id)
+    logger.info(f"删除切片信息的结果:{resp}")
+    return JSONResponse(resp)
+
+@app.delete("/rag/delete_doc/{doc_id}/{knowledge_id}")
+async def delete_by_doc_id(doc_id:str=None, knowledge_id:str=None):
+    logger.info(f"删除文档id接口,知识库:{knowledge_id}, 文档id:{doc_id}")
+    resp = MilvusOperate(collection_name=knowledge_id)._delete_by_doc_id(doc_id=doc_id)
+    logger.info(f"删除文档的结果:{resp}")
+    return JSONResponse(resp)
+
+@app.put("/rag/update_slice")
+async def put_by_id(slice_json:dict):
+    logger.info(f"更新切片信息的请求参数:{slice_json}")
+    collection_name = slice_json.get("knowledge_id")
+    resp = MilvusOperate(collection_name=collection_name)._put_by_id(slice_json)
+
+    logger.info(f"更新切片信息的结果:{resp}")
+    return JSONResponse(resp)
+
+@app.post("/rag/insert_slice")
+async def insert_slice_text(slice_json:dict):
+    logger.info(f"新增切片信息的请求参数:{slice_json}")
+    collection_name = slice_json.get("knowledge_id")
+    resp = MilvusOperate(collection_name=collection_name)._insert_slice(slice_json)
+
+    logger.info(f"新增切片信息的结果:{resp}")
+    return JSONResponse(resp)
+
+@app.get("/rag/search/{knowledge_id}/{slice_id}")
+async def search_by_doc_id(knowledge_id:str=None, slice_id:str=None):
+    # 根据切片id查询切片信息
+    # print(f"知识库:{knowledge_id}, 切片:{slice_id}")
+    logger.info(f"根据切片id查询的数据库名:{knowledge_id},切片id:{slice_id}")
+    collection_name = knowledge_id  # 根据传过来的id处理对应知识库
+    resp = MilvusOperate(collection_name=collection_name)._search_by_chunk_id(slice_id)
+
+    logger.info(f"根据切片id查询结果:{resp}")
+    return JSONResponse(resp)
+
+@app.post("/rag/search_word")
+async def search_by_key_word(search_json:dict):
+    # 根据doc_id 查询切片列表信息
+    collection_name = search_json.get("knowledge_id")
+    logger.info(f"根据关键字请求的参数:{search_json}")
+    resp = MilvusOperate(collection_name=collection_name)._search_by_key_word(search_json)
+
+    logger.info(f"根据关键字查询的结果:{resp}")
+    return JSONResponse(resp)
+
+@app.delete("/rag/delete_knowledge/{knowledge_id}")
+async def delete_collection(knowledge_id: str = None):
+    logger.info(f"删除数据库请求的参数:{knowledge_id}")
+    resp = MilvusOperate(collection_name=knowledge_id)._delete_collection()
+    logger.info(f"删除向量库结果:{resp}")
+
+    return JSONResponse(resp)
+
+@app.post("/rag/create_collection")
+async def create_collection(collection: dict):
+    collection_name = collection.get("knowledge_id")
+    embedding_name = collection.get("embedding_id")
+    logger.info(f"创建向量库的库名:{collection_name},向量名称:{embedding_name}")
+    resp = MilvusOperate(collection_name=collection_name, embedding_name=embedding_name)._create_collection()
+    logger.info(f"创建向量库结果:{resp}")
+    
+    return JSONResponse(resp)
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=18079)
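
A quick smoke test against the endpoints above once the service is running on port 18079 (the knowledge_id and embedding_id values are placeholders):

import requests

base = "http://127.0.0.1:18079"
print(requests.get(f"{base}/health").json())        # {"status": "healthy"}
resp = requests.post(f"{base}/rag/create_collection",
                     json={"knowledge_id": "kb_demo", "embedding_id": "e5"})
print(resp.json())                                   # {"code": 200, ...} if the collection was created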

+ 104 - 0
requirements.txt

@@ -0,0 +1,104 @@
+aiohappyeyeballs==2.6.1
+aiohttp==3.11.15
+aiosignal==1.3.2
+annotated-types==0.7.0
+anyio==4.9.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+attrs==25.3.0
+certifi==2025.1.31
+cffi==1.17.1
+charset-normalizer==3.4.1
+click==8.1.8
+colorama==0.4.6
+coloredlogs==15.0.1
+cryptography==44.0.2
+dataclasses-json==0.6.7
+distro==1.9.0
+fastapi==0.115.12
+filelock==3.13.1
+flatbuffers==25.2.10
+frozenlist==1.5.0
+fsspec==2024.6.1
+greenlet==3.1.1
+grpcio==1.67.1
+h11==0.14.0
+httpcore==1.0.7
+httpx==0.28.1
+httpx-sse==0.4.0
+huggingface-hub==0.30.1
+humanfriendly==10.0
+idna==3.10
+Jinja2==3.1.4
+jiter==0.9.0
+joblib==1.4.2
+jsonpatch==1.33
+jsonpointer==3.0.0
+langchain==0.3.22
+langchain-community==0.3.18
+langchain-core==0.3.49
+langchain-text-splitters==0.3.7
+langsmith==0.3.21
+MarkupSafe==2.1.5
+marshmallow==3.26.1
+milvus-model==0.2.12
+minio==7.2.15
+mpmath==1.3.0
+multidict==6.3.0
+mypy-extensions==1.0.0
+mysql-connector-python==9.2.0
+networkx==3.3
+numpy==1.26.4
+onnxruntime==1.21.0
+openai==1.72.0
+orjson==3.10.16
+packaging==24.2
+pandas==2.2.3
+pdfminer.six==20250327
+pdfplumber==0.11.6
+pillow==11.0.0
+propcache==0.3.1
+protobuf==6.30.2
+pycparser==2.22
+pycryptodome==3.22.0
+pydantic==2.11.1
+pydantic-settings==2.8.1
+pydantic_core==2.33.0
+pymilvus==2.5.4
+PyMuPDF==1.25.5
+pypdfium2==4.30.1
+pyreadline3==3.5.4
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.0
+pytz==2025.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+requests-toolbelt==1.0.0
+safetensors==0.5.3
+scikit-learn==1.6.1
+scipy==1.15.2
+sentence-transformers==4.0.1
+six==1.17.0
+sniffio==1.3.1
+SQLAlchemy==2.0.40
+sse-starlette==2.2.1
+starlette==0.46.1
+sympy==1.13.1
+tenacity==9.0.0
+threadpoolctl==3.6.0
+tokenizers==0.21.1
+torch==2.5.0+cu124
+torchaudio==2.5.0+cu124
+torchvision==0.20.0+cu124
+tqdm==4.67.1
+transformers==4.50.3
+typing-inspect==0.9.0
+typing-inspection==0.4.0
+typing_extensions==4.12.2
+tzdata==2025.2
+ujson==5.10.0
+urllib3==2.3.0
+uvicorn==0.34.0
+yarl==1.18.3
+zstandard==0.23.0

+ 85 - 0
response_info.py

@@ -0,0 +1,85 @@
+database_delete_update_response_mapping = {
+    "delete_query_error": ("根据id查询出错", 500),
+    "delete_no_result": ("向量库中未查询到对应的数据", 200),
+    "delete_error": ("根据id删除数据失败", 500),
+    "delete_success": ("删除成功", 200),
+    "update_query_error": ("更新数据出错", 500),
+    "update_query_no_result": ("未查询到需要更新的数据", 200),
+    "update_error": ("更新数据时出错", 500),
+    "update_success": ("更新成功", 200),
+    "delete_collection_error": ("删除数据库失败", 500),
+    "delete_collection_success": ("删除数据库成功", 200),
+    "create_collection_error": ("创建数据库失败", 500),
+    "create_collection_success": ("创建数据库成功", 200),
+    "insert_success": ("插入切片成功", 200),
+    "insert_error": ("插入切片失败", 500),
+}
+
+
+def generate_message(result):
+    message, code = database_delete_update_response_mapping.get(result, ("unknown result", 500))
+
+    response_dict = {
+        "code": code,
+        "message": message
+    }
+    return response_dict
+
+def generate_response(result, page_num=0, page_size=10):
+    if page_num:
+        if not result:
+            resp_dict = {"code": 200, "rows": [], "total": 0}
+    
+        elif "code" in result[0]:
+            resp_dict = {}
+            resp_dict.update(result[0])
+            resp_dict["rows"] = []
+            resp_dict["message"] = "查询向量库出错"
+            resp_dict["total"] = 0
+
+            # return resp_dict
+        else:
+            rows = []
+            total_len = len(result)
+            skip = int(page_num-1) * int(page_size)
+            iter_result = result[skip:skip+page_size]
+            for i in iter_result:
+                d = {}
+                d["slice_id"] = i.get("chunk_id")
+                d["document_id"] = i.get("doc_id")
+                d["slice_text"] = i.get("content")
+                d["slice_char_len"] = i.get("metadata").get("chunk_len")
+                rows.append(d)
+            resp_dict = {
+                "code": 200,
+                "rows": rows,
+                "total": total_len
+            }
+            
+        return resp_dict
+    else:
+        if not result:
+            resp_dict = {"code": 200, "data": {}}
+    
+        elif "code" in result[0]:
+            resp_dict = {}
+            resp_dict.update(result[0])
+            resp_dict["data"] = {}
+            resp_dict["message"] = "查询向量库出错"
+
+            # return resp_dict
+        else:
+            d = {}
+            for i in result:
+                d["slice_id"] = i.get("chunk_id")
+                d["document_id"] = i.get("doc_id")
+                d["slice_text"] = i.get("content")
+                # rows.append(d)
+            resp_dict = {
+                "code": 200,
+                "data": d,
+                # "total": total_len
+            }
+
+        return resp_dict
+    
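
An illustrative call of the paged branch of generate_response above (the Milvus query rows are mocked; the keys mirror what the function reads):

from response_info import generate_response

rows = [
    {"chunk_id": "c1", "doc_id": "d1", "content": "第一段", "metadata": {"chunk_len": 3}},
    {"chunk_id": "c2", "doc_id": "d1", "content": "第二段", "metadata": {"chunk_len": 3}},
]
print(generate_response(rows, page_num=1, page_size=1))
# {'code': 200, 'rows': [{'slice_id': 'c1', 'document_id': 'd1', 'slice_text': '第一段', 'slice_char_len': 3}], 'total': 2}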

+ 0 - 0
utils/__init__.py


BIN
utils/__pycache__/__init__.cpython-310.pyc


BIN
utils/__pycache__/__init__.cpython-311.pyc


BIN
utils/__pycache__/get_logger.cpython-310.pyc


BIN
utils/__pycache__/get_logger.cpython-311.pyc


BIN
utils/__pycache__/upload_file_to_oss.cpython-310.pyc


BIN
utils/__pycache__/upload_file_to_oss.cpython-311.pyc


+ 40 - 0
utils/get_logger.py

@@ -0,0 +1,40 @@
+import logging
+from logging.handlers import TimedRotatingFileHandler
+import os
+
+def setup_logger(file_name):
+    # 获取一个日志记录器
+    logger = logging.getLogger(file_name)
+    
+    # 设置日志级别
+    logger.setLevel(logging.DEBUG)  # 设置最低的日志级别为DEBUG
+
+    # 创建一个文件处理器,并设置日志文件的路径
+    # file_handler = logging.FileHandler(f"./logs/{file_name}.log", encoding="utf-8")
+    os.makedirs("./logs", exist_ok=True)  # ensure the log directory exists
+    log_file_path = os.path.join("./logs", f"{file_name}.log")
+    file_handler = TimedRotatingFileHandler(
+        filename=log_file_path,
+        when="midnight",       # 每天午夜轮转
+        interval=1,            # 每隔1天轮转一次
+        encoding="utf-8"
+    )
+    
+    # 设置文件处理器的日志级别
+    file_handler.setLevel(logging.DEBUG)
+    
+    # 创建一个控制台处理器(可选,用于同时输出到控制台)
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(logging.DEBUG)
+    
+    # 创建一个格式化器,并设置日志条目的格式
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s')
+    
+    # 将格式化器添加到处理器
+    file_handler.setFormatter(formatter)
+    console_handler.setFormatter(formatter)
+    
+    # 将处理器添加到日志记录器
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+    
+    return logger
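
Typical usage of the logger factory above (the name is whatever the caller passes; output goes to ./logs/<name>.log and the console):

from utils.get_logger import setup_logger

logger = setup_logger(__name__)
logger.info("service started")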

BIN
utils/simsun.ttc


+ 75 - 0
utils/upload_file_to_oss.py

@@ -0,0 +1,75 @@
+from minio import Minio
+from minio.error import S3Error
+from minio.deleteobjects import DeleteObject
+from config import minio_config
+from utils.get_logger import setup_logger
+
+logger = setup_logger(__name__)
+
+class UploadMinio():
+    def __init__(self):
+        self.minio_endpoint = minio_config.get("minio_endpoint")
+        self.minio_access_key = minio_config.get("minio_access_key")
+        self.minio_secret_key = minio_config.get("minio_secret_key")
+        self.minio_bucket = minio_config.get("minio_bucket")
+        self.flag = minio_config.get("flag")
+        self.minio_client = Minio(self.minio_endpoint,
+                    access_key=self.minio_access_key,
+                    secret_key=self.minio_secret_key,
+                    cert_check=False,
+                    # http_client=http_client,
+                    secure=self.flag)  # 如果使用https,请将secure设置为True
+
+    def upload_file(self, local_file_path, minio_bucket_file):
+        try:
+            # 使用put_object方法将文件上传到MinIO
+            self.minio_client.fput_object(self.minio_bucket, minio_bucket_file, local_file_path)
+        
+            logger.info(f"File {local_file_path} uploaded to MinIO successfully.")
+            return True
+        
+        except S3Error as e:
+            logger.error(f"Error uploading file to MinIO: {e}")
+            return False
+
+        except Exception as e:
+            logger.error(f"报错信息:{e}")
+            return False
+        
+    def upload_io_bytes(self, minio_bucket_file, bytes_obj):
+        try:
+            # 使用put_object方法将文件上传到MinIO
+            bytes_obj.seek(0, 2)                  # measure the stream length
+            length = bytes_obj.tell()
+            bytes_obj.seek(0)                     # rewind so put_object reads from the start
+            self.minio_client.put_object(self.minio_bucket, minio_bucket_file, bytes_obj, length)
+        
+            logger.info(f"File {minio_bucket_file} uploaded to MinIO successfully.")
+            return True
+        
+        except S3Error as e:
+            logger.error(f"Error uploading file to MinIO: {e}")
+            return False
+
+        except Exception as e:
+            logger.error(f"报错信息:{e}")
+            return False
+        
+    def delete_doc_id_images(self, knowledge_id, doc_id):
+        try:
+            objects = self.minio_client.list_objects(self.minio_bucket, prefix=f"pdf/{knowledge_id}/{doc_id}", recursive=True)
+
+            objects_to_delete = []
+            for obj in objects:
+                objects_to_delete.append(DeleteObject(obj.object_name))
+                # self.minio_client.remove_object(self.minio_bucket, obj.object_name)
+            error_responses = self.minio_client.remove_objects(self.minio_bucket, objects_to_delete)
+            for error in error_responses:
+                if error:
+                    delete_flag = False 
+                    logger.info("Deletion Error: {}".format(error))
+                    break
+            else:
+                logger.info("Success")
+                delete_flag = True
+            return delete_flag
+        except Exception as e:
+            logger.error(f"删除minio图片出错:{e}")
+            return False
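
A minimal usage sketch for the MinIO helper above (the local path and object key are placeholders; credentials come from config.minio_config):

from utils.upload_file_to_oss import UploadMinio

client = UploadMinio()
ok = client.upload_file("./tmp_file/images/demo.jpg", "pdf/kb_demo/doc-001/demo.jpg")
print("uploaded" if ok else "upload failed")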