paddleocr_load.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# PDF parsing loader backed by an external PaddleOCR-VL HTTP service.
import os
import re
import time
import aiohttp
from openai import AsyncOpenAI
import json_repair  # tolerant parsing of (possibly malformed) LLM JSON output
from utils.get_logger import setup_logger
from config import paddleocr_config
import asyncio

# Module-level logger shared by every method in this file.
logger = setup_logger(__name__)
class PaddleOCRLoader:
    """PaddleOCR-VL PDF parser.

    Posts a PDF path to an external PaddleOCR-VL HTTP service, converts the
    returned page JSON into the unified ``content_list`` format shared with
    the other parser loaders, and can optionally refine title levels via an
    LLM (see ``_optimize_titles``).
    """

    def __init__(self, file_json):
        # file_json: request payload describing the document; only
        # "knowledge_id" is read here — TODO confirm full schema with caller.
        self.file_json = file_json
        self.knowledge_id = file_json.get("knowledge_id")
        # Filled in later by extract_text() when the caller supplies a doc id.
        self.document_id = None
        # Endpoints and limits all come from the module-level paddleocr_config.
        self.service_url = paddleocr_config.get("service_url", "http://127.0.0.1:8119")
        self.vl_server_url = paddleocr_config.get("vl_server_url", "http://127.0.0.1:8118/v1")
        self.output_dir = paddleocr_config.get("output_dir", "./tmp_file/paddleocr_parsed")
        # Total HTTP timeout (seconds) for one parse request.
        self.timeout = paddleocr_config.get("timeout", 600)
        # Sub-config for the optional LLM-based title-level optimization.
        self.title_opt_config = paddleocr_config.get("title_optimization", {})
        logger.info(f"PaddleOCR Loader 初始化成功, service_url={self.service_url}")
  25. async def extract_text(self, pdf_path, doc_id=None):
  26. """
  27. 解析 PDF 文件
  28. 返回:
  29. (content_list, md_path, pdf_file_name) - 与其他解析器保持一致的返回格式
  30. """
  31. self.document_id = doc_id
  32. logger.info(f"开始使用 PaddleOCR-VL 解析 PDF: {pdf_path}, doc_id: {doc_id}")
  33. try:
  34. # 调用 PaddleOCR-VL 服务
  35. json_data, md_path, pdf_file_name = await self._call_service(pdf_path)
  36. if not json_data:
  37. raise ValueError("PaddleOCR-VL 服务返回空结果")
  38. # 转换为统一格式
  39. content_list = self._convert_to_content_list(json_data, pdf_file_name)
  40. # 标题优化(可选,默认不启用)
  41. content_list = await self._optimize_titles(content_list)
  42. logger.info(f"PaddleOCR-VL 解析完成,共 {len(content_list)} 个元素")
  43. return content_list, md_path, pdf_file_name
  44. except Exception as e:
  45. logger.error(f"PaddleOCR-VL 解析失败: {e}", exc_info=True)
  46. raise
  47. async def _call_service(self, pdf_path):
  48. """调用 PaddleOCR-VL 服务"""
  49. url = f"{self.service_url}/parse"
  50. payload = {
  51. "pdf_path": pdf_path,
  52. "output_dir": self.output_dir,
  53. "vl_rec_server_url": self.vl_server_url
  54. }
  55. MAX_RETRY = 5
  56. RETRY_INTERVAL = 10
  57. last_exception = None
  58. for attempt in range(1, MAX_RETRY + 1):
  59. try:
  60. # connector = aiohttp.TCPConnector(force_close=True) # 禁用 HTTP keep-alive
  61. # connector = aiohttp.TCPConnector(limit=50) # 将并发数量降低
  62. async with aiohttp.ClientSession() as session:
  63. async with session.post(
  64. url,
  65. json=payload,
  66. timeout=aiohttp.ClientTimeout(total=self.timeout)
  67. ) as resp:
  68. if resp.status != 200:
  69. raise ValueError(f"服务返回错误状态: {resp.status}")
  70. result = await resp.json()
  71. if result.get("code") != 200:
  72. raise ValueError(f"服务返回错误: {result.get('message')}")
  73. data = result.get("data", {})
  74. return (
  75. data.get("json_data"),
  76. data.get("md_path"),
  77. data.get("pdf_file_name")
  78. )
  79. except aiohttp.ClientError as e:
  80. last_exception = e
  81. logger.warning(f"PaddleOCR 服务调用失败 (尝试 {attempt}/{MAX_RETRY}): {e}")
  82. if attempt < MAX_RETRY:
  83. logger.info(f"等待 {RETRY_INTERVAL}s 后重试...")
  84. await asyncio.sleep(RETRY_INTERVAL)
  85. # 所有重试失败后才抛出异常
  86. raise RuntimeError(
  87. f"PaddleOCR predict 连续失败 {MAX_RETRY} 次,最后错误: {last_exception}"
  88. )
  89. def _convert_to_content_list(self, json_data, pdf_file_name):
  90. """
  91. 将 PaddleOCR-VL JSON 数据转换为统一的 content_list 格式
  92. block_label 类型映射:
  93. - text: 普通文本
  94. - paragraph_title: 标题(根据#号确定级别)
  95. - header: 页眉(排除)
  96. - figure_title: 表格/图片标题
  97. - table: 表格
  98. - number: 页码(排除)
  99. - content: 文本块 -> text
  100. - doc_title: 标题(默认2级)
  101. - image: 图片
  102. - vision_footnote: 页脚(排除)
  103. """
  104. content_list = []
  105. pages = json_data.get("pages", [])
  106. for page_data in pages:
  107. res = page_data.get("res", {})
  108. # 处理嵌套的 res 结构
  109. if "res" in res:
  110. res = res.get("res", {})
  111. page_idx = res.get("page_index", 0)
  112. parsing_res_list = res.get("parsing_res_list", [])
  113. # 处理 figure_title + table 组合
  114. i = 0
  115. while i < len(parsing_res_list):
  116. block = parsing_res_list[i]
  117. block_label = block.get("block_label", "")
  118. block_content = block.get("block_content", "")
  119. block_bbox = block.get("block_bbox", [])
  120. # 排除页眉、页脚、页码
  121. if block_label in ("header", "number", "vision_footnote"):
  122. i += 1
  123. continue
  124. # figure_title 处理
  125. if block_label == "figure_title":
  126. # 检查下一个是否是 table
  127. if i + 1 < len(parsing_res_list):
  128. next_block = parsing_res_list[i + 1]
  129. if next_block.get("block_label") == "table":
  130. # 将 figure_title 作为 table_caption
  131. table_content = self._process_table(
  132. next_block, page_idx, table_caption=block_content
  133. )
  134. if table_content:
  135. content_list.append(table_content)
  136. i += 2 # 跳过当前和下一个
  137. continue
  138. # 否则当作 text 处理
  139. content_list.append({
  140. "type": "text",
  141. "text": block_content,
  142. "text_level": None,
  143. "page_idx": page_idx,
  144. "bbox": block_bbox
  145. })
  146. i += 1
  147. continue
  148. # 处理各种类型
  149. if block_label == "text" or block_label == "content":
  150. content_list.append({
  151. "type": "text",
  152. "text": block_content,
  153. "text_level": None,
  154. "page_idx": page_idx,
  155. "bbox": block_bbox
  156. })
  157. elif block_label == "paragraph_title":
  158. text_level = self._extract_title_level(block_content)
  159. clean_text = self._remove_title_prefix(block_content)
  160. content_list.append({
  161. "type": "text",
  162. "text": clean_text,
  163. "text_level": text_level,
  164. "page_idx": page_idx,
  165. "bbox": block_bbox
  166. })
  167. elif block_label == "doc_title":
  168. # doc_title 默认为 2 级标题
  169. content_list.append({
  170. "type": "text",
  171. "text": block_content,
  172. "text_level": 2,
  173. "page_idx": page_idx,
  174. "bbox": block_bbox
  175. })
  176. elif block_label == "table":
  177. table_content = self._process_table(block, page_idx)
  178. if table_content:
  179. content_list.append(table_content)
  180. elif block_label == "image":
  181. image_content = self._process_image(block, page_idx, pdf_file_name)
  182. if image_content:
  183. content_list.append(image_content)
  184. i += 1
  185. return content_list
  186. def _extract_title_level(self, text):
  187. """从文本开头的 # 号提取标题级别"""
  188. if not text:
  189. return None
  190. match = re.match(r'^(#{1,6})\s+', text)
  191. if match:
  192. return len(match.group(1))
  193. return None
  194. def _remove_title_prefix(self, text):
  195. """移除文本开头的 # 号前缀"""
  196. if not text:
  197. return text
  198. return re.sub(r'^#{1,6}\s+', '', text)
  199. def _process_table(self, block, page_idx, table_caption=None):
  200. """处理表格块"""
  201. block_content = block.get("block_content", "")
  202. block_bbox = block.get("block_bbox", [])
  203. img_path = block.get("img_path", "")
  204. return {
  205. "type": "table",
  206. "table_body": block_content,
  207. "table_caption": [table_caption] if table_caption else [],
  208. "page_idx": page_idx,
  209. "bbox": block_bbox,
  210. "img_path": img_path
  211. }
  212. def _process_image(self, block, page_idx, pdf_file_name):
  213. """
  214. 处理图片块
  215. 图片内容格式:
  216. <div style="text-align: center;"><img src="imgs/img_xxx.jpg" ... /></div>
  217. 需要提取图片 URL
  218. """
  219. block_content = block.get("block_content", "")
  220. block_bbox = block.get("block_bbox", [])
  221. # 提取图片路径
  222. match = re.search(r'src="([^"]+)"', block_content)
  223. if match:
  224. img_path = match.group(1)
  225. return {
  226. "type": "image",
  227. "img_path": img_path,
  228. "page_idx": page_idx,
  229. "bbox": block_bbox
  230. }
  231. return None
  232. def _collect_title_candidates(self, content_list):
  233. """收集需要优化的标题候选项"""
  234. title_dict = {}
  235. origin_indices = []
  236. idx = 0
  237. for i, item in enumerate(content_list):
  238. if item.get("type") != "text":
  239. continue
  240. text = item.get("text", "")
  241. text_level = item.get("text_level")
  242. bbox = item.get("bbox", [0, 0, 0, 0])
  243. line_height = int(bbox[3] - bbox[1]) if len(bbox) >= 4 else 0
  244. page_idx = item.get("page_idx", 0) + 1
  245. # 条件1: 已有 text_level(已识别为标题)
  246. if text_level is not None:
  247. title_dict[str(idx)] = [text, line_height, page_idx]
  248. origin_indices.append(i)
  249. idx += 1
  250. # 条件2: 以数字编号开头(可能遗漏的标题)
  251. elif re.match(r'^\d+(\.\d+)*([、.\s]|(?=[a-zA-Z\u4e00-\u9fa5]))', text):
  252. title_dict[str(idx)] = [text, line_height, page_idx]
  253. origin_indices.append(i)
  254. idx += 1
  255. return title_dict, origin_indices
    async def _optimize_titles_with_llm(self, title_dict):
        """Ask an LLM to assign hierarchy levels and tree paths to titles.

        Args:
            title_dict: {str id: [title text, line height, 1-based page]} as
                produced by _collect_title_candidates.

        Returns:
            (levels_dict, paths_dict) keyed by int id — level 0 marks a
            misdetected non-title — or (None, None) when title_dict is empty
            or every retry fails / returns a size-mismatched answer.
        """
        if not title_dict:
            return None, None
        # Credentials and endpoint come from the title_optimization sub-config.
        client = AsyncOpenAI(
            api_key=self.title_opt_config.get("api_key"),
            base_url=self.title_opt_config.get("base_url"),
        )
        # NOTE: the prompt is runtime behavior (it shapes the LLM output
        # contract) and is therefore left verbatim.
        prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
1. 字典中每个value均为一个list,包含以下元素:
- 标题文本
- 文本行高是标题所在块的平均行高
- 标题所在的页码
2. 保留原始内容:
- 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
- 请务必保证输出的字典中元素的数量和输入的数量一致
3. 保持字典内key-value的对应关系不变
4. 优化层次结构:
- 为每个标题元素添加适当的层次结构
- 行高较大的标题一般是更高级别的标题
- 标题从前至后的层级必须是连续的,不能跳过层级
- 标题层级最多为5级,不要添加过多的层级
- 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
5. 合理性检查与微调:
- 在完成初步分级后,仔细检查分级结果的合理性
- 根据上下文关系和逻辑顺序,对不合理的分级进行微调
- 确保最终的分级结果符合文档的实际结构和逻辑
- 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
- 一般的,如开头部分存在特殊字符,如*、#、&等,例如"1.0*DL、1.0#DL、1.0&DL",你可以将这些字符的标题层级标记为 0 来排除它们
6. 生成树形路径:
- 从标题文本中提取编号(如"3.1 安全管理"提取"3.1"),无编号则用标题文本
- 遇到1级标题时,以它作为新的根节点,路径就是它自身的编号/文本
- 后续非1级标题的路径 = 当前根节点 + "->" + 各级父标题编号 + "->" + 自身编号
- 遇到下一个1级标题时,切换为新的根节点,重复上述逻辑
- 层级为0的非标题项,路径标记为"0"
完整示例(注意paths的value是从标题文本提取的编号,不是标题id):
输入:
{{"0":["前言",24,1],"1":["目次",24,2],"2":["1 总则",22,3],"3":["1.0.1 为科学评价",18,3],"4":["1.0.2 本标准适用",18,3],"5":["3 检查评定项目",22,5],"6":["3.1 安全管理",20,5],"7":["3.1.3 保证项目",18,6],"8":["1 安全生产责任制",16,6],"9":["2 施工组织设计",16,6]}}
输出:
{{"levels":{{"0":1,"1":1,"2":1,"3":2,"4":2,"5":1,"6":2,"7":3,"8":4,"9":4}},"paths":{{"0":"前言","1":"目次","2":"1","3":"1->1.0.1","4":"1->1.0.2","5":"3","6":"3->3.1","7":"3->3.1->3.1.3","8":"3->3.1->3.1.3->1","9":"3->3.1->3.1.3->2"}}}}
IMPORTANT:
请返回一个JSON对象,包含两个字段:
- "levels": 层级字典,key是标题id,value是层级数字
- "paths": 路径字典,key是标题id,value是从标题文本提取编号后构建的路径
不需要对JSON格式化,不需要返回任何其他信息。
Input title list:
{title_dict}
Corrected title list:
"""
        max_retries = self.title_opt_config.get("max_retries", 3)
        model = self.title_opt_config.get("model", "deepseek-chat")
        for retry in range(max_retries):
            try:
                completion = await client.chat.completions.create(
                    model=model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=0.7,
                    stream=True,
                )
                # Accumulate streamed delta fragments into the full response.
                content_pieces = []
                async for chunk in completion:
                    if chunk.choices and chunk.choices[0].delta.content:
                        content_pieces.append(chunk.choices[0].delta.content)
                content = "".join(content_pieces).strip()
                logger.info(f"LLM 标题优化响应: {content[:200]}...")
                # Reasoning models may prepend a <think> block; keep only
                # the text after the closing tag.
                if "</think>" in content:
                    content = content[content.index("</think>") + len("</think>"):].strip()
                # json_repair tolerates slightly malformed JSON from the model.
                result = json_repair.loads(content)
                levels_dict = {int(k): int(v) for k, v in result.get("levels", {}).items()}
                paths_dict = {int(k): str(v) for k, v in result.get("paths", {}).items()}
                # Accept only a complete answer: exactly one level per input title.
                if len(levels_dict) == len(title_dict):
                    return levels_dict, paths_dict
                else:
                    logger.warning(f"LLM返回数量不匹配: {len(levels_dict)} vs {len(title_dict)}")
            except Exception as e:
                logger.warning(f"LLM调用失败 (retry {retry+1}/{max_retries}): {e}")
        logger.error("LLM标题优化达到最大重试次数")
        return None, None
  334. def _apply_title_optimization(self, content_list, origin_indices, levels_dict, paths_dict):
  335. """将优化结果回写到 content_list"""
  336. for i, idx in enumerate(origin_indices):
  337. level = levels_dict.get(i, 0)
  338. content_list[idx]["text_level"] = level if level > 0 else None
  339. if paths_dict and i in paths_dict:
  340. content_list[idx]["title_path"] = paths_dict[i]
  341. if level == 0:
  342. content_list[idx].pop("title_path", None)
  343. async def _optimize_titles(self, content_list):
  344. """标题优化入口"""
  345. if self.title_opt_config.get("enable", False):
  346. return content_list
  347. try:
  348. title_dict, origin_indices = self._collect_title_candidates(content_list)
  349. if not title_dict:
  350. logger.info("未发现需要优化的标题")
  351. return content_list
  352. logger.info(f"收集到 {len(title_dict)} 个标题候选项,开始LLM优化")
  353. levels_dict, paths_dict = await self._optimize_titles_with_llm(title_dict)
  354. if levels_dict:
  355. self._apply_title_optimization(content_list, origin_indices, levels_dict, paths_dict)
  356. logger.info("标题优化完成")
  357. else:
  358. logger.warning("标题优化失败,保留原始数据")
  359. except Exception as e:
  360. logger.error(f"标题优化异常: {e}", exc_info=True)
  361. return content_list