import os
import time
import uuid
from typing import List, Dict
from uuid import uuid1

import httpx
import requests
from fastapi import HTTPException
from langchain_text_splitters import RecursiveCharacterTextSplitter

from rag.document_load.pdf_load import PDFLoader
from rag.document_load.txt_load import TextLoad
from rag.db import MilvusOperate, MysqlOperate
from utils.get_logger import setup_logger

logger = setup_logger(__name__)

# Maps a lower-cased file extension to the loader class used to parse it.
# TextLoad is imported but intentionally not registered here yet.
file_dict = {
    "pdf": PDFLoader,
}

# (connect, read) timeout in seconds for fetching the remote source file, so
# an unresponsive server cannot block the request handler indefinitely.
DOWNLOAD_TIMEOUT = (10, 120)


class ParseFile:
    """Download a remote document, split it into chunks and persist them.

    The constructor reads the following keys from ``file_json``: ``name``,
    ``flag`` (``"upload"`` or ``"update"``), ``knowledge_id``,
    ``document_id`` and ``url``; ``file_split`` additionally reads
    ``set_slice`` and ``slice_value``.

    Constructing an instance has side effects: it may download the file to
    ``./tmp_file/<knowledge_id>/<name>`` and it creates the MySQL and Milvus
    client objects.
    """

    def __init__(self, file_json):
        self.file_json = file_json
        self.file_name = self.file_json.get("name")
        self.file_list = self.file_json.get("name").split(".")
        # Use the LAST dot-separated segment as the extension so names such
        # as "report.v2.pdf" resolve to "pdf", and names without any dot fall
        # through to the default loader instead of raising IndexError.
        file_type = self.file_list[-1].lower() if len(self.file_list) > 1 else ""
        self.flag = self.file_json.get("flag")
        self.knowledge_id = self.file_json.get("knowledge_id")
        self.doc_id = self.file_json.get("document_id")
        # The file must be on disk before the loader is constructed.
        self.save_file_to_tmp()
        self.load_file = file_dict.get(file_type, PDFLoader)(self.file_json)
        self.mysql_client = MysqlOperate()
        self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)

    def _download_to(self, url, dest_path):
        """Fetch ``url`` and write the response body to ``dest_path``.

        Raises:
            requests.exceptions.RequestException: on connection errors,
                timeouts, or a non-2xx HTTP status. Failing here prevents an
                HTTP error page from being stored and parsed as the document.
        """
        response = requests.get(url=url, timeout=DOWNLOAD_TIMEOUT)
        response.raise_for_status()
        with open(dest_path, "wb") as f:
            f.write(response.content)

    def save_file_to_tmp(self):
        """Store the remote file locally so it can be processed.

        For ``flag == "upload"`` the file is always downloaded; for
        ``flag == "update"`` it is downloaded only when no local copy exists.
        Any other flag leaves the filesystem untouched apart from creating
        the per-knowledge-base directory.
        """
        url = self.file_json.get("url")
        know_path = f"./tmp_file/{self.knowledge_id}"
        os.makedirs(know_path, exist_ok=True)
        # NOTE(review): knowledge_id and file_name come straight from the
        # incoming payload and are interpolated into the path unsanitized.
        # Presumably the loader rebuilds this same path from file_json, so
        # sanitizing only here could desynchronize the two - confirm and fix
        # both sides together.
        tmp_file_name = f"{know_path}/{self.file_name}"

        needs_download = self.flag == "upload" or (
            self.flag == "update" and not os.path.exists(tmp_file_name)
        )
        if needs_download:
            self._download_to(url, tmp_file_name)

    def file_split(self, file_text):
        """Split ``file_text`` into chunks and return them as a list of str.

        The separator is chosen from ``set_slice``: the preset values
        ``"0"``, ``"1"`` and ``"2"`` all map to a newline; any other value
        uses the custom ``slice_value`` (with a literal ``\\n`` sequence
        converted to a real newline).
        """
        split_map = {
            "0": ["\n"],
            "1": ["\n"],
            "2": ["\n"],
        }
        separator_num = self.file_json.get("set_slice")
        # ``or ""`` also covers the key being present with a None value,
        # which a plain .get(key, "") default would not.
        slice_value = (self.file_json.get("slice_value") or "").replace("\\n", "\n")
        separator = split_map.get(separator_num) or [slice_value]
        logger.info(f"文本切分字符:{separator}")
        text_split = RecursiveCharacterTextSplitter(
            separators=separator,
            chunk_size=300,
            chunk_overlap=20,
            length_function=len,
        )
        return text_split.split_text(file_text)

    def process_data_to_milvus_schema(self, text_lists):
        """Wrap each text chunk in the record format passed to the stores.

        Each record has the form::

            {
                "content": text,
                "doc_id": doc_id,
                "chunk_id": chunk_id,
                "metadata": {"source": file_name, "chunk_index": index},
            }

        ``chunk_id`` is a fresh ``uuid1`` string and ``chunk_index`` is
        1-based.
        """
        return [
            {
                "content": text,
                "doc_id": self.doc_id,
                "chunk_id": str(uuid1()),
                "metadata": {"source": self.file_name, "chunk_index": i + 1},
            }
            for i, text in enumerate(text_lists)
        ]

    def save_file_to_db(self):
        """Parse the file, split it and persist the chunks.

        Returns:
            dict: ``{"code": 200, "message": ...}`` on success, or
            ``{"code": 500, "message": ...}`` when a storage step fails or
            the ``flag`` is neither ``"upload"`` nor ``"update"``.
        """
        flag = self.file_json.get("flag")
        if flag not in ("upload", "update"):
            # Previously an unknown flag ran the whole pipeline and then
            # ended without a well-defined response.
            return {"code": 500, "message": f"不支持的flag: {flag}"}

        if flag == "update":
            # Re-slicing a document: drop the chunks previously stored for
            # this doc_id before inserting the new ones.
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            self.mysql_client.delete_to_slice(doc_id=self.doc_id)

        file_text_start_time = time.time()
        file_text, image_dict = self.load_file.file2text()
        file_text_end_time = time.time()
        logger.info(f"pdf加载成文本耗时:{file_text_end_time - file_text_start_time}")

        text_lists = self.file_split(file_text)
        file_split_end_time = time.time()
        logger.info(f"文档切分的耗时:{file_split_end_time - file_text_end_time}")

        docs = self.process_data_to_milvus_schema(text_lists)
        logger.info(f"插入milvus的数据:{docs}")

        if flag == "upload":
            return self._store_uploaded(docs, image_dict)
        return self._store_updated(docs)

    def _store_uploaded(self, docs, image_dict):
        """Persist a newly uploaded document: Milvus, then slices, then images."""
        insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
        if not insert_milvus_flag:
            return {"code": 500, "message": insert_milvus_str}

        insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(
            docs, self.knowledge_id, self.doc_id
        )
        if not insert_slice_flag:
            # Roll back the vectors so the two stores stay consistent.
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}

        insert_img_flag, insert_mysql_info = self.mysql_client.insert_to_image_url(
            image_dict, self.knowledge_id, self.doc_id
        )
        if not insert_img_flag:
            # Roll back everything written so far. The slice rows are removed
            # too; otherwise they would point at chunks that no longer exist
            # in Milvus.
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            self.mysql_client.delete_to_slice(doc_id=self.doc_id)
            self.mysql_client.delete_image_url(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}

        return {"code": 200, "message": "文档解析成功"}

    def _store_updated(self, docs):
        """Persist the re-sliced chunks of an existing document."""
        insert_milvus_start_time = time.time()
        insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
        insert_milvus_end_time = time.time()
        logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
        if not insert_milvus_flag:
            return {"code": 500, "message": insert_milvus_str}

        insert_mysql_start_time = time.time()
        insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(
            docs, self.knowledge_id, self.doc_id
        )
        insert_mysql_end_time = time.time()
        logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
        if not insert_slice_flag:
            # Roll back the vectors so the two stores stay consistent.
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}

        return {"code": 200, "message": "切片修改成功"}