#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""PaddleOCR-VL based PDF parser.

Posts a PDF path to a PaddleOCR-VL HTTP service, converts the service's
block-level JSON output into the unified ``content_list`` format shared by
the other document loaders, and optionally post-processes title levels and
title paths with an LLM.
"""
import os
import re
import time
import asyncio

import aiohttp
import json_repair
from openai import AsyncOpenAI

from utils.get_logger import setup_logger
from config import paddleocr_config

logger = setup_logger(__name__)


class PaddleOCRLoader:
    """PaddleOCR-VL PDF parser (drop-in peer of the other loaders)."""

    def __init__(self, file_json):
        """
        :param file_json: metadata dict describing the file; only
                          ``knowledge_id`` is read here.
        """
        self.file_json = file_json
        self.knowledge_id = file_json.get("knowledge_id")
        # Filled in by extract_text() when the caller supplies a doc id.
        self.document_id = None
        self.service_url = paddleocr_config.get("service_url", "http://127.0.0.1:8119")
        self.vl_server_url = paddleocr_config.get("vl_server_url", "http://127.0.0.1:8118/v1")
        self.output_dir = paddleocr_config.get("output_dir", "./tmp_file/paddleocr_parsed")
        self.timeout = paddleocr_config.get("timeout", 600)
        # LLM-based title optimization settings; feature is off by default.
        self.title_opt_config = paddleocr_config.get("title_optimization", {})
        logger.info(f"PaddleOCR Loader 初始化成功, service_url={self.service_url}")

    async def extract_text(self, pdf_path, doc_id=None):
        """Parse a PDF file.

        :param pdf_path: path of the PDF to parse (passed through to the service).
        :param doc_id: optional document id, stored on the instance.
        :returns: ``(content_list, md_path, pdf_file_name)`` — same return
                  shape as the other parsers.
        :raises ValueError: if the service returns an empty result.
        :raises Exception: any parsing error is logged and re-raised.
        """
        self.document_id = doc_id
        logger.info(f"开始使用 PaddleOCR-VL 解析 PDF: {pdf_path}, doc_id: {doc_id}")
        try:
            # Call the PaddleOCR-VL service.
            json_data, md_path, pdf_file_name = await self._call_service(pdf_path)
            if not json_data:
                raise ValueError("PaddleOCR-VL 服务返回空结果")
            # Convert to the unified content_list format.
            content_list = self._convert_to_content_list(json_data, pdf_file_name)
            # Title optimization (optional, disabled by default).
            content_list = await self._optimize_titles(content_list)
            logger.info(f"PaddleOCR-VL 解析完成,共 {len(content_list)} 个元素")
            return content_list, md_path, pdf_file_name
        except Exception as e:
            logger.error(f"PaddleOCR-VL 解析失败: {e}", exc_info=True)
            raise

    async def _call_service(self, pdf_path):
        """Call the PaddleOCR-VL service with retries on network errors.

        Only ``aiohttp.ClientError`` triggers a retry; protocol errors (bad
        HTTP status, non-200 service code) raise ``ValueError`` immediately.

        :returns: ``(json_data, md_path, pdf_file_name)`` from the payload.
        :raises RuntimeError: after MAX_RETRY consecutive network failures.
        """
        url = f"{self.service_url}/parse"
        payload = {
            "pdf_path": pdf_path,
            "output_dir": self.output_dir,
            "vl_rec_server_url": self.vl_server_url
        }
        MAX_RETRY = 5
        RETRY_INTERVAL = 10
        last_exception = None
        for attempt in range(1, MAX_RETRY + 1):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        url,
                        json=payload,
                        timeout=aiohttp.ClientTimeout(total=self.timeout)
                    ) as resp:
                        if resp.status != 200:
                            raise ValueError(f"服务返回错误状态: {resp.status}")
                        result = await resp.json()
                        if result.get("code") != 200:
                            raise ValueError(f"服务返回错误: {result.get('message')}")
                        data = result.get("data", {})
                        return (
                            data.get("json_data"),
                            data.get("md_path"),
                            data.get("pdf_file_name")
                        )
            except aiohttp.ClientError as e:
                last_exception = e
                logger.warning(f"PaddleOCR 服务调用失败 (尝试 {attempt}/{MAX_RETRY}): {e}")
                if attempt < MAX_RETRY:
                    logger.info(f"等待 {RETRY_INTERVAL}s 后重试...")
                    await asyncio.sleep(RETRY_INTERVAL)
        # Raise only after every retry attempt has failed.
        raise RuntimeError(
            f"PaddleOCR predict 连续失败 {MAX_RETRY} 次,最后错误: {last_exception}"
        )

    def _convert_to_content_list(self, json_data, pdf_file_name):
        """Convert PaddleOCR-VL JSON data into the unified content_list format.

        block_label mapping:
            - text: plain text
            - paragraph_title: title (level derived from leading ``#`` marks)
            - header: page header (dropped)
            - figure_title: table/figure caption
            - table: table
            - number: page number (dropped)
            - content: text block -> text
            - doc_title: title (fixed level 2)
            - image: image
            - vision_footnote: page footer (dropped)
        """
        content_list = []
        pages = json_data.get("pages", [])
        for page_data in pages:
            res = page_data.get("res", {})
            # Some payloads nest the result one level deeper.
            if "res" in res:
                res = res.get("res", {})
            page_idx = res.get("page_index", 0)
            parsing_res_list = res.get("parsing_res_list", [])
            # Manual index loop so a figure_title + table pair can be consumed
            # together (two blocks at once).
            i = 0
            while i < len(parsing_res_list):
                block = parsing_res_list[i]
                block_label = block.get("block_label", "")
                block_content = block.get("block_content", "")
                block_bbox = block.get("block_bbox", [])
                # Drop headers, footers and page numbers.
                if block_label in ("header", "number", "vision_footnote"):
                    i += 1
                    continue
                # figure_title handling.
                if block_label == "figure_title":
                    # If the next block is a table, use this text as its caption.
                    if i + 1 < len(parsing_res_list):
                        next_block = parsing_res_list[i + 1]
                        if next_block.get("block_label") == "table":
                            table_content = self._process_table(
                                next_block,
                                page_idx,
                                table_caption=block_content
                            )
                            if table_content:
                                content_list.append(table_content)
                            i += 2  # consume both the caption and the table
                            continue
                    # Otherwise treat the caption as plain text.
                    content_list.append({
                        "type": "text",
                        "text": block_content,
                        "text_level": None,
                        "page_idx": page_idx,
                        "bbox": block_bbox
                    })
                    i += 1
                    continue
                # Remaining label types.
                if block_label == "text" or block_label == "content":
                    content_list.append({
                        "type": "text",
                        "text": block_content,
                        "text_level": None,
                        "page_idx": page_idx,
                        "bbox": block_bbox
                    })
                elif block_label == "paragraph_title":
                    text_level = self._extract_title_level(block_content)
                    clean_text = self._remove_title_prefix(block_content)
                    content_list.append({
                        "type": "text",
                        "text": clean_text,
                        "text_level": text_level,
                        "page_idx": page_idx,
                        "bbox": block_bbox
                    })
                elif block_label == "doc_title":
                    # doc_title defaults to a level-2 title.
                    content_list.append({
                        "type": "text",
                        "text": block_content,
                        "text_level": 2,
                        "page_idx": page_idx,
                        "bbox": block_bbox
                    })
                elif block_label == "table":
                    table_content = self._process_table(block, page_idx)
                    if table_content:
                        content_list.append(table_content)
                elif block_label == "image":
                    image_content = self._process_image(block, page_idx, pdf_file_name)
                    if image_content:
                        content_list.append(image_content)
                i += 1
        return content_list

    def _extract_title_level(self, text):
        """Return the title level from leading ``#`` marks, or None."""
        if not text:
            return None
        match = re.match(r'^(#{1,6})\s+', text)
        if match:
            return len(match.group(1))
        return None

    def _remove_title_prefix(self, text):
        """Strip a leading ``#`` prefix (Markdown-style) from *text*."""
        if not text:
            return text
        return re.sub(r'^#{1,6}\s+', '', text)

    def _process_table(self, block, page_idx, table_caption=None):
        """Build a unified table entry from a table block."""
        block_content = block.get("block_content", "")
        block_bbox = block.get("block_bbox", [])
        img_path = block.get("img_path", "")
        return {
            "type": "table",
            "table_body": block_content,
            "table_caption": [table_caption] if table_caption else [],
            "page_idx": page_idx,
            "bbox": block_bbox,
            "img_path": img_path
        }

    def _process_image(self, block, page_idx, pdf_file_name):
        """Build a unified image entry from an image block.

        The block content is HTML-like markup containing an ``src="..."``
        attribute; the image URL is extracted from it. Returns None when no
        src attribute is found.
        """
        block_content = block.get("block_content", "")
        block_bbox = block.get("block_bbox", [])
        # Extract the image path from the src attribute.
        match = re.search(r'src="([^"]+)"', block_content)
        if match:
            img_path = match.group(1)
            return {
                "type": "image",
                "img_path": img_path,
                "page_idx": page_idx,
                "bbox": block_bbox
            }
        return None

    def _collect_title_candidates(self, content_list):
        """Collect title candidates to be optimized.

        :returns: ``(title_dict, origin_indices)`` where title_dict maps a
                  sequential string id to ``[text, line_height, page_number]``
                  and origin_indices maps those ids back to content_list
                  positions.
        """
        title_dict = {}
        origin_indices = []
        idx = 0
        for i, item in enumerate(content_list):
            if item.get("type") != "text":
                continue
            text = item.get("text", "")
            text_level = item.get("text_level")
            bbox = item.get("bbox", [0, 0, 0, 0])
            line_height = int(bbox[3] - bbox[1]) if len(bbox) >= 4 else 0
            page_idx = item.get("page_idx", 0) + 1  # 1-based page number
            # Case 1: already recognized as a title (has a text_level).
            if text_level is not None:
                title_dict[str(idx)] = [text, line_height, page_idx]
                origin_indices.append(i)
                idx += 1
            # Case 2: starts with a numeric outline prefix (possibly a missed title).
            elif re.match(r'^\d+(\.\d+)*([、.\s]|(?=[a-zA-Z\u4e00-\u9fa5]))', text):
                title_dict[str(idx)] = [text, line_height, page_idx]
                origin_indices.append(i)
                idx += 1
        return title_dict, origin_indices

    async def _optimize_titles_with_llm(self, title_dict):
        """Ask an LLM to assign hierarchy levels and tree paths to titles.

        :returns: ``(levels_dict, paths_dict)`` keyed by int title id, or
                  ``(None, None)`` on failure / after max retries.
        """
        if not title_dict:
            return None, None
        client = AsyncOpenAI(
            api_key=self.title_opt_config.get("api_key"),
            base_url=self.title_opt_config.get("base_url"),
        )
        prompt = f"""输入的内容是一篇文档中所有标题组成的字典,请根据以下指南优化标题的结果,使结果符合正常文档的层次结构:
1. 字典中每个value均为一个list,包含以下元素:
    - 标题文本
    - 文本行高是标题所在块的平均行高
    - 标题所在的页码
2. 保留原始内容:
    - 输入的字典中所有元素都是有效的,不能删除字典中的任何元素
    - 请务必保证输出的字典中元素的数量和输入的数量一致
3. 保持字典内key-value的对应关系不变
4. 优化层次结构:
    - 为每个标题元素添加适当的层次结构
    - 行高较大的标题一般是更高级别的标题
    - 标题从前至后的层级必须是连续的,不能跳过层级
    - 标题层级最多为5级,不要添加过多的层级
    - 优化后的标题只保留代表该标题的层级的整数,不要保留其他信息
5. 合理性检查与微调:
    - 在完成初步分级后,仔细检查分级结果的合理性
    - 根据上下文关系和逻辑顺序,对不合理的分级进行微调
    - 确保最终的分级结果符合文档的实际结构和逻辑
    - 字典中可能包含被误当成标题的正文,你可以通过将其层级标记为 0 来排除它们
    - 一般的,如开头部分存在特殊字符,如*、#、&等,例如"1.0*DL、1.0#DL、1.0&DL",你可以将这些字符的标题层级标记为 0 来排除它们
6. 生成树形路径:
    - 从标题文本中提取编号(如"3.1 安全管理"提取"3.1"),无编号则用标题文本
    - 遇到1级标题时,以它作为新的根节点,路径就是它自身的编号/文本
    - 后续非1级标题的路径 = 当前根节点 + "->" + 各级父标题编号 + "->" + 自身编号
    - 遇到下一个1级标题时,切换为新的根节点,重复上述逻辑
    - 层级为0的非标题项,路径标记为"0"

完整示例(注意paths的value是从标题文本提取的编号,不是标题id):
输入:
{{"0":["前言",24,1],"1":["目次",24,2],"2":["1 总则",22,3],"3":["1.0.1 为科学评价",18,3],"4":["1.0.2 本标准适用",18,3],"5":["3 检查评定项目",22,5],"6":["3.1 安全管理",20,5],"7":["3.1.3 保证项目",18,6],"8":["1 安全生产责任制",16,6],"9":["2 施工组织设计",16,6]}}
输出:
{{"levels":{{"0":1,"1":1,"2":1,"3":2,"4":2,"5":1,"6":2,"7":3,"8":4,"9":4}},"paths":{{"0":"前言","1":"目次","2":"1","3":"1->1.0.1","4":"1->1.0.2","5":"3","6":"3->3.1","7":"3->3.1->3.1.3","8":"3->3.1->3.1.3->1","9":"3->3.1->3.1.3->2"}}}}

IMPORTANT: 请返回一个JSON对象,包含两个字段:
- "levels": 层级字典,key是标题id,value是层级数字
- "paths": 路径字典,key是标题id,value是从标题文本提取编号后构建的路径
不需要对JSON格式化,不需要返回任何其他信息。

Input title list:
{title_dict}

Corrected title list:
"""
        max_retries = self.title_opt_config.get("max_retries", 3)
        model = self.title_opt_config.get("model", "deepseek-chat")
        for retry in range(max_retries):
            try:
                completion = await client.chat.completions.create(
                    model=model,
                    messages=[{'role': 'user', 'content': prompt}],
                    temperature=0.7,
                    stream=True,
                )
                content_pieces = []
                async for chunk in completion:
                    if chunk.choices and chunk.choices[0].delta.content:
                        content_pieces.append(chunk.choices[0].delta.content)
                content = "".join(content_pieces).strip()
                logger.info(f"LLM 标题优化响应: {content[:200]}...")
                # Reasoning models may prepend a <think>...</think> section;
                # keep only the text after the closing tag.
                # (Reconstructed: the original check was a no-op `"" in content`,
                # apparently a mangled `</think>` literal.)
                if "</think>" in content:
                    content = content[content.index("</think>") + len("</think>"):].strip()
                result = json_repair.loads(content)
                levels_dict = {int(k): int(v) for k, v in result.get("levels", {}).items()}
                paths_dict = {int(k): str(v) for k, v in result.get("paths", {}).items()}
                if len(levels_dict) == len(title_dict):
                    return levels_dict, paths_dict
                else:
                    logger.warning(f"LLM返回数量不匹配: {len(levels_dict)} vs {len(title_dict)}")
            except Exception as e:
                logger.warning(f"LLM调用失败 (retry {retry+1}/{max_retries}): {e}")
        logger.error("LLM标题优化达到最大重试次数")
        return None, None

    def _apply_title_optimization(self, content_list, origin_indices, levels_dict, paths_dict):
        """Write the optimization results back into *content_list*.

        Level 0 means "not actually a title": text_level is cleared and any
        title_path is removed.
        """
        for i, idx in enumerate(origin_indices):
            level = levels_dict.get(i, 0)
            content_list[idx]["text_level"] = level if level > 0 else None
            if paths_dict and i in paths_dict:
                content_list[idx]["title_path"] = paths_dict[i]
            if level == 0:
                content_list[idx].pop("title_path", None)

    async def _optimize_titles(self, content_list):
        """Title-optimization entry point; returns content_list (maybe updated).

        BUGFIX: the enable check was inverted — optimization ran only when the
        feature was *disabled*. It now runs only when explicitly enabled.
        """
        if not self.title_opt_config.get("enable", False):
            return content_list
        try:
            title_dict, origin_indices = self._collect_title_candidates(content_list)
            if not title_dict:
                logger.info("未发现需要优化的标题")
                return content_list
            logger.info(f"收集到 {len(title_dict)} 个标题候选项,开始LLM优化")
            levels_dict, paths_dict = await self._optimize_titles_with_llm(title_dict)
            if levels_dict:
                self._apply_title_optimization(content_list, origin_indices, levels_dict, paths_dict)
                logger.info("标题优化完成")
            else:
                logger.warning("标题优化失败,保留原始数据")
        except Exception as e:
            logger.error(f"标题优化异常: {e}", exc_info=True)
        return content_list