file_process.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import requests
  2. from fastapi import HTTPException
  3. from typing import List, Dict
  4. import os
  5. from uuid import uuid1
  6. import uuid
  7. from rag.document_load.pdf_load import PDFLoader
  8. from rag.document_load.txt_load import TextLoad
  9. from langchain_text_splitters import RecursiveCharacterTextSplitter
  10. from rag.db import MilvusOperate, MysqlOperate
  11. import httpx
  12. import time
  13. from utils.get_logger import setup_logger
  14. logger = setup_logger(__name__)
  15. file_dict = {
  16. "pdf": PDFLoader,
  17. # "txt": TextLoad
  18. }
  19. class ParseFile:
  20. def __init__(self, file_json):
  21. self.file_json = file_json
  22. self.file_name = self.file_json.get("name")
  23. # self.file_url = self.file_json
  24. self.file_list = self.file_json.get("name").split(".")
  25. file_type = self.file_list[1]
  26. self.flag = self.file_json.get("flag")
  27. self.knowledge_id = self.file_json.get("knowledge_id")
  28. self.doc_id = self.file_json.get("document_id")
  29. self.save_file_to_tmp()
  30. self.load_file = file_dict.get(file_type, PDFLoader)(self.file_json)
  31. self.mysql_client = MysqlOperate()
  32. self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)
  33. def save_file_to_tmp(self):
  34. # 远程文件存到本地处理
  35. url = self.file_json.get("url")
  36. know_path = "./tmp_file" + f"/{self.knowledge_id}"
  37. os.makedirs(know_path, exist_ok=True)
  38. tmp_file_name = f"./tmp_file/{self.knowledge_id}/{self.file_name}"
  39. if self.flag == "upload":
  40. file_response = requests.get(url=url)
  41. with open(tmp_file_name, "wb") as f:
  42. f.write(file_response.content)
  43. elif self.flag == "update":
  44. if os.path.exists(tmp_file_name):
  45. pass
  46. else:
  47. file_response = requests.get(url=url)
  48. with open(tmp_file_name, "wb") as f:
  49. f.write(file_response.content)
  50. # return file_name
  51. def file_split(self, file_text):
  52. split_map = {
  53. "0": ["\n"],
  54. "1": ["\n"],
  55. "2": ["\n"]
  56. }
  57. separator_num = self.file_json.get("set_slice")
  58. slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
  59. separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
  60. logger.info(f"文本切分字符:{separator}")
  61. text_split = RecursiveCharacterTextSplitter(
  62. separators=separator,
  63. chunk_size=300,
  64. chunk_overlap=20,
  65. length_function=len
  66. )
  67. texts = text_split.split_text(file_text)
  68. return texts
  69. def process_data_to_milvus_schema(self, text_lists):
  70. """组织数据格式:
  71. {
  72. "content": text,
  73. "doc_id": doc_id,
  74. "chunk_id": chunk_id,
  75. "metadata": {"source": file_name},
  76. }
  77. """
  78. # doc_id = self.file_json.get("document_id")
  79. docs = []
  80. for i, text in enumerate(text_lists):
  81. chunk_id = str(uuid1())
  82. d = {
  83. "content": text,
  84. "doc_id": self.doc_id,
  85. "chunk_id": chunk_id,
  86. "metadata": {"source": self.file_name, "chunk_index": i+1}
  87. }
  88. # d["content"] = text
  89. # d["doc_id"] = doc_id
  90. # d["chunk_id"] = chunk_id
  91. # d["metadata"] = {"source": self.file_name, "chunk_index": i+1}
  92. docs.append(d)
  93. return docs
  94. def save_file_to_db(self):
  95. # 如果更改切片方式,需要删除对应knowledge id中doc id对应数据
  96. flag = self.file_json.get("flag")
  97. if flag == "update":
  98. # 执行删除操作
  99. self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
  100. self.mysql_client.delete_to_slice(doc_id=self.doc_id)
  101. # self.mysql_client.delete_image_url(doc_id=doc_id)
  102. file_text_start_time = time.time()
  103. file_text, image_dict = self.load_file.file2text()
  104. file_text_end_time = time.time()
  105. logger.info(f"pdf加载成文本耗时:{file_text_end_time - file_text_start_time}")
  106. text_lists = self.file_split(file_text)
  107. file_split_end_time = time.time()
  108. logger.info(f"文档切分的耗时:{file_split_end_time - file_text_end_time}")
  109. docs = self.process_data_to_milvus_schema(text_lists)
  110. logger.info(f"插入milvus的数据:{docs}")
  111. # doc_id = self.file_json.get("document_id")
  112. if flag == "upload":
  113. # 插入到milvus库中
  114. insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
  115. if insert_milvus_flag:
  116. # 插入到mysql的slice info数据库中
  117. insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, self.doc_id)
  118. else:
  119. resp = {"code": 500, "message": insert_milvus_str}
  120. return resp
  121. if insert_slice_flag:
  122. # 插入mysql中的bm_media_replacement表中
  123. insert_img_flag, insert_mysql_info = self.mysql_client.insert_to_image_url(image_dict, self.knowledge_id, self.doc_id)
  124. else:
  125. resp = {"code": 500, "message": insert_mysql_info}
  126. self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
  127. return resp
  128. if insert_img_flag:
  129. resp = {"code": 200, "message": "文档解析成功"}
  130. else:
  131. self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
  132. self.mysql_client.delete_image_url(doc_id=self.doc_id)
  133. resp = {"code": 500, "message": insert_mysql_info}
  134. return resp
  135. elif flag == "update":
  136. # 插入到milvus库中
  137. insert_milvus_start_time = time.time()
  138. insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
  139. # insert_milvus_flag, insert_milvus_str = self.milvus_client._batch_insert_data(docs,text_lists)
  140. insert_milvus_end_time = time.time()
  141. logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
  142. if insert_milvus_flag:
  143. # 插入到mysql的slice info数据库中
  144. insert_mysql_start_time = time.time()
  145. insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, self.doc_id)
  146. insert_mysql_end_time = time.time()
  147. logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
  148. else:
  149. resp = {"code": 500, "message": insert_milvus_str}
  150. return resp
  151. if insert_slice_flag:
  152. resp = {"code": 200, "message": "切片修改成功"}
  153. else:
  154. self.milvus_client._delete_by_doc_id(doc_id=self.doc_id)
  155. resp = {"code":500, "message": insert_mysql_info}
  156. return resp