- import requests
- from fastapi import HTTPException
- from typing import List, Dict
- import os
- from uuid import uuid1
- import uuid
- from rag.document_load.pdf_load import PDFLoader
- from rag.document_load.txt_load import TextLoad
- from langchain_text_splitters import RecursiveCharacterTextSplitter
- from rag.db import MilvusOperate, MysqlOperate
- import httpx
- import time
- from utils.get_logger import setup_logger
# Module-level logger for parse/persist progress and timing messages.
logger = setup_logger(__name__)

# Registry mapping a lowercase file extension to the loader class that
# turns that file type into text. Unknown extensions fall back to the
# PDF loader at the lookup site.
file_dict = {
    "pdf": PDFLoader,
    # "txt": TextLoad
}
class ParseFile:
    """Parse a remote document, split it into chunks, and persist the
    chunks to Milvus (vectors) and MySQL (slice / image metadata).

    Expected keys in ``file_json`` (seen used below): ``name``, ``url``,
    ``flag`` ("upload" or "update"), ``knowledge_id``, ``document_id``,
    ``set_slice``, ``slice_value``.
    """

    def __init__(self, file_json):
        """Download the file locally and build the loader + DB clients.

        NOTE: performs network and DB I/O (download, MySQL/Milvus client
        construction) as a side effect of construction.
        """
        self.file_json = file_json
        self.file_name = self.file_json.get("name")
        self.file_list = self.file_json.get("name").split(".")
        # Use the LAST dot-separated segment as the extension so names
        # like "report.v2.pdf" resolve to "pdf" (the original [1] picked
        # "v2" for multi-dot names); lowercase for case-insensitive match.
        file_type = self.file_list[-1].lower()
        self.flag = self.file_json.get("flag")
        self.knowledge_id = self.file_json.get("knowledge_id")
        self.doc_id = self.file_json.get("document_id")
        self.save_file_to_tmp()
        # Unknown extensions fall back to the PDF loader.
        self.load_file = file_dict.get(file_type, PDFLoader)(self.file_json)
        self.mysql_client = MysqlOperate()
        # One Milvus collection per knowledge base.
        self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)

    def save_file_to_tmp(self):
        """Save the remote file to ./tmp_file/<knowledge_id>/<name>.

        "upload" always downloads a fresh copy; "update" reuses an
        existing local copy when present. Other flags download nothing.
        """
        url = self.file_json.get("url")
        know_path = f"./tmp_file/{self.knowledge_id}"
        os.makedirs(know_path, exist_ok=True)
        tmp_file_name = f"./tmp_file/{self.knowledge_id}/{self.file_name}"
        # "update" may reuse a copy left over from the initial upload.
        if self.flag == "update" and os.path.exists(tmp_file_name):
            return
        if self.flag in ("upload", "update"):
            # timeout prevents a hung download from blocking the worker;
            # raise_for_status stops an HTTP error page from being saved
            # as the document body.
            file_response = requests.get(url=url, timeout=60)
            file_response.raise_for_status()
            with open(tmp_file_name, "wb") as f:
                f.write(file_response.content)

    def file_split(self, file_text):
        """Split raw document text into chunks.

        ``set_slice`` selects a preset separator; any other value falls
        back to the caller-supplied ``slice_value`` (with literal "\\n"
        unescaped to a real newline).
        """
        split_map = {
            "0": ["\n"],
            "1": ["\n"],
            "2": ["\n"],
        }
        separator_num = self.file_json.get("set_slice")
        slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
        # Single lookup: preset when known, custom separator otherwise.
        separator = split_map.get(separator_num) or [slice_value]
        logger.info(f"文本切分字符:{separator}")
        text_split = RecursiveCharacterTextSplitter(
            separators=separator,
            chunk_size=300,
            chunk_overlap=20,
            length_function=len,
        )
        return text_split.split_text(file_text)

    def process_data_to_milvus_schema(self, text_lists):
        """Shape chunk texts into the Milvus insert schema.

        Each entry:
        {
            "content": text,
            "doc_id": doc_id,
            "chunk_id": chunk_id,       # fresh uuid1 per chunk
            "metadata": {"source": file_name, "chunk_index": 1-based},
        }
        """
        docs = []
        for i, text in enumerate(text_lists):
            docs.append({
                "content": text,
                "doc_id": self.doc_id,
                "chunk_id": str(uuid1()),
                "metadata": {"source": self.file_name, "chunk_index": i + 1},
            })
        return docs

    def save_file_to_db(self):
        """Load, split, and persist the document.

        Returns ``{"code": int, "message": str}`` for "upload"/"update"
        flags; returns None for any other flag (original behavior kept).
        Failed later steps roll back earlier writes for this doc_id.
        """
        flag = self.file_json.get("flag")
        if flag == "update":
            # Re-slicing: delete the rows previously written for this
            # doc_id in both stores before inserting the new chunks.
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            self.mysql_client.delete_to_slice(doc_id=self.doc_id)
        file_text_start_time = time.time()
        file_text, image_dict = self.load_file.file2text()
        file_text_end_time = time.time()
        logger.info(f"pdf加载成文本耗时:{file_text_end_time - file_text_start_time}")
        text_lists = self.file_split(file_text)
        file_split_end_time = time.time()
        logger.info(f"文档切分的耗时:{file_split_end_time - file_text_end_time}")
        docs = self.process_data_to_milvus_schema(text_lists)
        logger.info(f"插入milvus的数据:{docs}")
        if flag == "upload":
            return self._persist_upload(docs, image_dict)
        elif flag == "update":
            return self._persist_update(docs)

    def _persist_upload(self, docs, image_dict):
        """Persist a new document: Milvus vectors, then slice rows, then
        image-URL rows; each failure rolls back the earlier writes."""
        insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
        if not insert_milvus_flag:
            return {"code": 500, "message": insert_milvus_str}
        insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(
            docs, self.knowledge_id, self.doc_id)
        if not insert_slice_flag:
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}
        # bm_media_replacement table: chunk-image mapping.
        insert_img_flag, insert_mysql_info = self.mysql_client.insert_to_image_url(
            image_dict, self.knowledge_id, self.doc_id)
        if not insert_img_flag:
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            self.mysql_client.delete_image_url(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}
        return {"code": 200, "message": "文档解析成功"}

    def _persist_update(self, docs):
        """Persist re-sliced chunks (old rows already deleted by the
        caller): Milvus vectors, then MySQL slice rows, with rollback."""
        insert_milvus_start_time = time.time()
        insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
        insert_milvus_end_time = time.time()
        logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
        if not insert_milvus_flag:
            return {"code": 500, "message": insert_milvus_str}
        insert_mysql_start_time = time.time()
        insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(
            docs, self.knowledge_id, self.doc_id)
        insert_mysql_end_time = time.time()
        logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
        if not insert_slice_flag:
            self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
            return {"code": 500, "message": insert_mysql_info}
        return {"code": 200, "message": "切片修改成功"}