# documents_process.py
  1. import aiohttp
  2. import aiofiles
  3. from rag.db import MilvusOperate, MysqlOperate
  4. from rag.document_load.pdf_load import MinerUParsePdf
  5. from rag.document_load.office_load import MinerUParseOffice
  6. from rag.document_load.txt_load import TextLoad
  7. from rag.document_load.image_load import MinerUParseImage
  8. from utils.upload_file_to_oss import UploadMinio
  9. from utils.get_logger import setup_logger
  10. from config import minio_config
  11. import os
  12. import time
  13. from uuid import uuid1
  14. from langchain_text_splitters import RecursiveCharacterTextSplitter
# Module-level parser singletons shared by every ProcessDocuments instance;
# _get_file_type picks one of these based on the file extension.
pdf_parse = MinerUParsePdf()
office_parse = MinerUParseOffice()
text_parse = TextLoad()
image_parse = MinerUParseImage()
# Module-scoped logger (configured by utils.get_logger.setup_logger).
logger = setup_logger(__name__)
class ProcessDocuments():
    """Downloads, parses, splits and indexes the documents described by a request payload."""
    def __init__(self, file_json):
        # file_json: request payload dict; keys read by this class include
        # "knowledge_id", "docs", "flag", "set_slice", "set_table", "slice_value".
        self.file_json = file_json
        self.knowledge_id = self.file_json.get("knowledge_id")
        self.mysql_client = MysqlOperate()
        self.minio_client = UploadMinio()
        # One Milvus collection per knowledge base: the collection is named
        # after the knowledge_id.
        self.milvus_client = MilvusOperate(collection_name=self.knowledge_id)
  27. def _get_file_type(self, name):
  28. if name.endswith(".txt"):
  29. return text_parse
  30. elif name.endswith('.pdf'):
  31. return pdf_parse
  32. elif name.endswith((".doc", ".docx", "ppt", "pptx")):
  33. return office_parse
  34. elif name.endswith((".jpg", "png", "jpeg")):
  35. return image_parse
  36. else:
  37. raise "不支持的文件格式"
  38. async def save_file_temp(self, session, url, name):
  39. down_file_path = "./tmp_file" + f"/{self.knowledge_id}"
  40. # down_file_path = "./tmp_file"
  41. os.makedirs(down_file_path, exist_ok=True)
  42. down_file_name = down_file_path + f"/{name}"
  43. # if os.path.exists(down_file_name):
  44. # pass
  45. # else:
  46. async with session.get(url, ssl=False) as resp:
  47. resp.raise_for_status()
  48. content_length = resp.headers.get('Content-Length')
  49. if content_length:
  50. file_size = int(content_length)
  51. else:
  52. file_size = 0
  53. async with aiofiles.open(down_file_name, 'wb') as f:
  54. async for chunk in resp.content.iter_chunked(1024):
  55. await f.write(chunk)
  56. return down_file_name, file_size
  57. def file_split_by_len(self, file_text):
  58. split_map = {
  59. "0": ["#"], # 按标题段落切片
  60. "1": ["<page>"], # 按页切片
  61. "2": ["\n"] # 按问答对
  62. }
  63. separator_num = self.file_json.get("set_slice")
  64. slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
  65. separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
  66. logger.info(f"文本切分字符:{separator}")
  67. text_split = RecursiveCharacterTextSplitter(
  68. separators=separator,
  69. chunk_size=500,
  70. chunk_overlap=40,
  71. length_function=len
  72. )
  73. texts = text_split.split_text(file_text)
  74. return texts
  75. def split_text(self, file_text):
  76. text_split = RecursiveCharacterTextSplitter(
  77. separators=["\n\n", "\n"],
  78. chunk_size=500,
  79. chunk_overlap=40,
  80. length_function=len
  81. )
  82. texts = text_split.split_text(file_text)
  83. return texts
    def split_by_title(self, file_content_list, set_table, doc_id):
        """Split parsed content into chunks at heading boundaries.

        A level-1 heading ("# ") starts a new chunk; a second level-2 heading
        ("## ") under the same level-1 heading also starts a new chunk,
        re-prefixed with the current level-1 heading. Images (and tables when
        set_table != "1") are uploaded to MinIO and replaced in the text by a
        placeholder token; flag_img_info maps each placeholder to its URL.

        Returns:
            (text_lists, flag_img_info)
        """
        # TODO: heading-based splitting; images become placeholder tokens;
        # tables per set_table: "0" -> keep as image, "1" -> keep HTML body.
        text_lists = []
        text = ""            # chunk currently being accumulated
        image_num = 1        # running index used in placeholder tokens
        flag_img_info = {}   # placeholder token -> MinIO URL
        level_1_text = ""    # most recent "# ..." heading line
        level_2_text = ""    # most recent "## ..." heading line under it
        for i, content_dict in enumerate(file_content_list):
            text_type = content_dict.get("type")
            content_text = content_dict.get("text")
            if text_type == "text":
                text_level = content_dict.get("text_level", "")
                if text_level == 1:
                    if not level_1_text:
                        # First level-1 heading seen: start accumulating.
                        level_1_text = f"# {content_text}\n"
                        text += f"# {content_text}\n"
                    else:
                        # New level-1 heading: flush the previous chunk.
                        text_lists.append(text)
                        text = f"# {content_text}\n"
                        level_1_text = f"# {content_text}\n"
                        level_2_text = ""
                elif text_level == 2:
                    if not level_2_text:
                        text += f"## {content_text}\n"
                        level_2_text = f"## {content_text}\n"
                    else:
                        # Subsequent level-2 heading: flush, and re-prefix the
                        # new chunk with the enclosing level-1 heading.
                        text_lists.append(text)
                        text = level_1_text + f"## {content_text}\n"
                else:
                    if text_level:
                        # Deeper headings stay inside the current chunk.
                        text += text_level*"#" + " " + content_text + "\n"
                    else:
                        text += content_text
            elif text_type == "table" and set_table == "1":
                # Keep the raw HTML of the table inline.
                text += content_dict.get("table_body")
            elif text_type in ("image", "table"):
                image_path = content_dict.get("img_path")
                if not image_path:
                    continue
                # assumes img_path looks like "<dir>/<file>" — TODO confirm;
                # split("/")[1] takes only the second path component.
                image_name = image_path.split("/")[1]
                save_image_path = "./tmp_file/images/" + f"/{image_name}"
                replace_text = f"【示意图序号_{doc_id}_{image_num}】"
                minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
                self.minio_client.upload_file(save_image_path, minio_file_path)
                minio_url = minio_config.get("minio_url")
                minio_bucket = minio_config.get("minio_bucket")
                flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
                text += replace_text
                image_num += 1
            # Flush the trailing chunk on the final element.
            # NOTE(review): if the final element is an image without img_path,
            # the `continue` above skips this flush — confirm intended.
            if i+1 == len(file_content_list):
                text_lists.append(text)
        return text_lists, flag_img_info
  137. def split_by_page(self, file_content_list, set_table, doc_id):
  138. # TODO 处理按照页面切分,图片处理成标识符,表格按照set table 0图片,1html数据
  139. text_lists = []
  140. current_page = ""
  141. text = ""
  142. image_num = 1
  143. flag_img_info = {}
  144. for i,content_dict in enumerate(file_content_list):
  145. page_index = content_dict.get("page_idx")
  146. if i == 0:
  147. current_page = page_index
  148. elif page_index != current_page:
  149. text_lists.append(text)
  150. text = ""
  151. current_page = page_index
  152. text_type = content_dict.get("type")
  153. if text_type == "text":
  154. content_text = content_dict.get("text")
  155. text_level = content_dict.get("text_level")
  156. if text_level:
  157. text += "#" * text_level + " " + content_text
  158. else:
  159. text += content_text
  160. elif text_type == "table" and set_table == "1":
  161. text += content_dict.get("table_body")
  162. elif text_type in ("image", "table"):
  163. image_path = content_dict.get("img_path")
  164. image_name = image_path.split("/")[1]
  165. save_image_path = "./tmp_file/images/" + f"/{image_name}"
  166. replace_text = f"【示意图序号_{doc_id}_{image_num}】"
  167. minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
  168. self.minio_client.upload_file(save_image_path, minio_file_path)
  169. minio_url = minio_config.get("minio_url")
  170. minio_bucket = minio_config.get("minio_bucket")
  171. flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
  172. text += replace_text
  173. image_num += 1
  174. if i+1 == len(file_content_list):
  175. text_lists.append(text)
  176. return text_lists, flag_img_info
  177. def split_by_self(self, file_content_list, set_table, slice_value, doc_id):
  178. # TODO 按照自定义的符号切分,图片处理成标识符,表格按照set table 0图片,1html数据,长度控制500以内,超过500切断
  179. logger.info(f"自定义的分隔符:{slice_value}")
  180. text = ""
  181. image_num = 1
  182. flag_img_info = {}
  183. for i, content_dict in enumerate(file_content_list):
  184. text_type = content_dict.get("type")
  185. if text_type == "text":
  186. content_text = content_dict.get("text")
  187. text_level = content_dict.get("text_level")
  188. if text_level:
  189. text += "#" * text_level + " " + content_text
  190. else:
  191. text += content_text
  192. elif text_type == "table" and set_table == "1":
  193. text += content_dict.get("table_body")
  194. elif text_type in ("image", "table"):
  195. image_path = content_dict.get("img_path")
  196. image_name = image_path.split("/")[1]
  197. save_image_path = "./tmp_file/images/" + f"/{image_name}"
  198. replace_text = f"【示意图序号_{doc_id}_{image_num}】"
  199. minio_file_path = f"/pdf/{self.knowledge_id}/{doc_id}/{replace_text}.jpg"
  200. self.minio_client.upload_file(save_image_path, minio_file_path)
  201. minio_url = minio_config.get("minio_url")
  202. minio_bucket = minio_config.get("minio_bucket")
  203. flag_img_info[replace_text] = f"{minio_url}/{minio_bucket}/{minio_file_path}"
  204. text += replace_text
  205. image_num += 1
  206. split_lists = text.split(slice_value)
  207. text_lists = []
  208. for split_text in split_lists:
  209. r = len(split_text)//500
  210. if r >= 1:
  211. for i in range(r+1):
  212. t = split_text[i*500:(i+1)*500]
  213. if t:
  214. text_lists.append(t)
  215. else:
  216. text_lists.append(split_text)
  217. return text_lists, flag_img_info
  218. def file_split(self, file_content_list, doc_id):
  219. # TODO 根据文本列表进行切分 返回切分列表和存储图片的链接
  220. separator_num = self.file_json.get("set_slice")
  221. set_table = self.file_json.get("set_table")
  222. # separator = split_map.get(separator_num) if split_map.get(separator_num) else [slice_value]
  223. # logger.info(f"文本切分字符:{separator}")
  224. if isinstance(file_content_list, str):
  225. file_text = file_content_list
  226. text_lists = self.split_text(file_text)
  227. return text_lists, {}
  228. elif separator_num == "0":
  229. # 使用标题段落切分,使用text_level=1,2 切分即一个# 还是两个#
  230. text_lists, flag_img_info = self.split_by_title(file_content_list, set_table, doc_id)
  231. return text_lists, flag_img_info
  232. elif separator_num == "1":
  233. # 按照页面方式切分
  234. text_lists, flag_img_info = self.split_by_page(file_content_list, set_table, doc_id)
  235. return text_lists, flag_img_info
  236. elif separator_num == "2":
  237. # 按照问答对切分 针对exce文档,暂不实现
  238. return [], {}
  239. else:
  240. # 自定义切分的方式,按照自定义字符以及文本长度切分,超过500
  241. slice_value = self.file_json.get("slice_value", "").replace("\\n", "\n")
  242. text_lists, flag_img_info = self.split_by_self(file_content_list, set_table, slice_value, doc_id)
  243. return text_lists, flag_img_info
  244. def process_data_to_milvus_schema(self, text_lists, doc_id, name):
  245. """组织数据格式:
  246. {
  247. "content": text,
  248. "doc_id": doc_id,
  249. "chunk_id": chunk_id,
  250. "metadata": {"source": file_name},
  251. }
  252. """
  253. docs = []
  254. total_len = 0
  255. for i, text in enumerate(text_lists):
  256. chunk_id = str(uuid1())
  257. chunk_len = len(text)
  258. total_len += chunk_len
  259. d = {
  260. "content": text,
  261. "doc_id": doc_id,
  262. "chunk_id": chunk_id,
  263. "metadata": {"source": name, "chunk_index": i+1, "chunk_len": chunk_len}
  264. }
  265. docs.append(d)
  266. return docs, total_len
  267. async def process_documents(self, file_json):
  268. # 文档下载
  269. separator_num = file_json.get("set_slice")
  270. if separator_num == "2":
  271. return {"code": 500, "message": "暂不支持解析"}
  272. docs = file_json.get("docs")
  273. flag = file_json.get("flag")
  274. success_doc = [] # 记录解析成功的文档id
  275. for doc in docs:
  276. url = doc.get("url")
  277. name = doc.get("name")
  278. doc_id = doc.get("document_id")
  279. async with aiohttp.ClientSession() as session:
  280. down_file_name, file_size = await self.save_file_temp(session, url, name)
  281. file_parse = self._get_file_type(name)
  282. file_content_list = await file_parse.extract_text(down_file_name)
  283. logger.info(f"mineru解析的pdf数据:{file_content_list}")
  284. text_lists, flag_img_info = self.file_split(file_content_list, doc_id)
  285. docs, total_char_len = self.process_data_to_milvus_schema(text_lists, doc_id, name)
  286. logger.info(f"存储到milvus的文本数据:{docs}")
  287. if flag == "upload":
  288. # 插入到milvus库中
  289. insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, doc_id)
  290. if insert_slice_flag:
  291. # 插入到mysql的slice info数据库中
  292. insert_img_flag, insert_mysql_info = self.mysql_client.insert_to_image_url(flag_img_info, self.knowledge_id, doc_id)
  293. else:
  294. insert_img_flag = False
  295. parse_file_status = False
  296. if insert_img_flag:
  297. insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
  298. # 插入mysql中的bm_media_replacement表中
  299. else:
  300. # self.milvus_client._delete_by_doc_id(doc_id=doc_id)
  301. insert_milvus_flag = False
  302. # return resp
  303. parse_file_status = False
  304. if insert_milvus_flag:
  305. parse_file_status = True
  306. else:
  307. self.mysql_client.delete_to_slice(doc_id=doc_id)
  308. # self.milvus_client._delete_by_doc_id(doc_id=doc_id)
  309. self.mysql_client.delete_image_url(doc_id=doc_id)
  310. # resp = {"code": 500, "message": insert_mysql_info}
  311. parse_file_status = False
  312. # return resp
  313. elif flag == "update": # 更新切片方式
  314. # 先把库中的数据删除
  315. self.milvus_client._delete_by_doc_id(doc_id=doc_id)
  316. self.mysql_client.delete_to_slice(doc_id=doc_id)
  317. insert_milvus_start_time = time.time()
  318. insert_slice_flag, insert_mysql_info = self.mysql_client.insert_to_slice(docs, self.knowledge_id, doc_id)
  319. # insert_milvus_flag, insert_milvus_str = self.milvus_client._batch_insert_data(docs,text_lists)
  320. insert_milvus_end_time = time.time()
  321. logger.info(f"插入milvus数据库耗时:{insert_milvus_end_time - insert_milvus_start_time}")
  322. if insert_slice_flag:
  323. # 插入到mysql的slice info数据库中
  324. insert_mysql_start_time = time.time()
  325. insert_milvus_flag, insert_milvus_str = self.milvus_client._insert_data(docs)
  326. insert_mysql_end_time = time.time()
  327. logger.info(f"插入mysql数据库耗时:{insert_mysql_end_time - insert_mysql_start_time}")
  328. else:
  329. # resp = {"code": 500, "message": insert_milvus_str}
  330. # return resp
  331. insert_milvus_flag = False
  332. parse_file_status = False
  333. if insert_milvus_flag:
  334. # resp = {"code": 200, "message": "切片修改成功"}
  335. parse_file_status = True
  336. else:
  337. self.mysql_client.delete_to_slice(doc_id=doc_id)
  338. # self.milvus_client._delete_by_doc_id(doc_id=doc_id)
  339. # resp = {"code":500, "message": insert_mysql_info}
  340. parse_file_status = False
  341. # return resp
  342. if parse_file_status:
  343. success_doc.append(doc_id)
  344. else:
  345. if flag == "upload":
  346. for del_id in success_doc:
  347. self.milvus_client._delete_by_doc_id(doc_id=del_id)
  348. self.mysql_client.delete_image_url(doc_id=del_id)
  349. self.mysql_client.delete_to_slice(doc_id=del_id)
  350. return {"code": 500, "message": "解析失败", "knowledge_id" : self.knowledge_id, "doc_info": {}}
  351. return {"code": 200, "message": "解析成功", "knowledge_id" : self.knowledge_id, "doc_info": {"file_size": file_size, "total_char_len": total_char_len, "slice_num": len(text_lists)}}