hybrid_model_output_to_middle_json.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. from loguru import logger
  7. from mineru.backend.hybrid.hybrid_magic_model import MagicModel
  8. from mineru.backend.utils import cross_page_table_merge
  9. from mineru.utils.config_reader import get_table_enable, get_llm_aided_config
  10. from mineru.utils.cut_image import cut_image_and_table
  11. from mineru.utils.enum_class import ContentType
  12. from mineru.utils.hash_utils import bytes_md5
  13. from mineru.utils.ocr_utils import OcrConfidence
  14. from mineru.utils.pdf_image_tools import get_crop_img
  15. from mineru.version import __version__
  16. heading_level_import_success = False
  17. llm_aided_config = get_llm_aided_config()
  18. if llm_aided_config:
  19. title_aided_config = llm_aided_config.get('title_aided', {})
  20. if title_aided_config.get('enable', False):
  21. try:
  22. from mineru.utils.llm_aided import llm_aided_title
  23. from mineru.backend.pipeline.model_init import AtomModelSingleton
  24. heading_level_import_success = True
  25. except Exception as e:
  26. logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
  27. "please execute `pip install mineru[core]` to install the required packages.")
  28. def blocks_to_page_info(
  29. page_blocks,
  30. page_inline_formula,
  31. page_ocr_res,
  32. image_dict,
  33. page,
  34. image_writer,
  35. page_index,
  36. _ocr_enable,
  37. _vlm_ocr_enable,
  38. ) -> dict:
  39. """将blocks转换为页面信息"""
  40. scale = image_dict["scale"]
  41. page_pil_img = image_dict["img_pil"]
  42. page_img_md5 = bytes_md5(page_pil_img.tobytes())
  43. width, height = map(int, page.get_size())
  44. magic_model = MagicModel(
  45. page_blocks,
  46. page_inline_formula,
  47. page_ocr_res,
  48. page,
  49. scale,
  50. page_pil_img,
  51. width,
  52. height,
  53. _ocr_enable,
  54. _vlm_ocr_enable,
  55. )
  56. image_blocks = magic_model.get_image_blocks()
  57. table_blocks = magic_model.get_table_blocks()
  58. title_blocks = magic_model.get_title_blocks()
  59. discarded_blocks = magic_model.get_discarded_blocks()
  60. code_blocks = magic_model.get_code_blocks()
  61. ref_text_blocks = magic_model.get_ref_text_blocks()
  62. phonetic_blocks = magic_model.get_phonetic_blocks()
  63. list_blocks = magic_model.get_list_blocks()
  64. # 如果有标题优化需求,计算标题的平均行高
  65. if heading_level_import_success:
  66. if _vlm_ocr_enable: # vlm_ocr导致没有line信息,需要重新det获取平均行高
  67. atom_model_manager = AtomModelSingleton()
  68. ocr_model = atom_model_manager.get_atom_model(
  69. atom_model_name='ocr',
  70. ocr_show_log=False,
  71. det_db_box_thresh=0.3,
  72. lang='ch_lite'
  73. )
  74. for title_block in title_blocks:
  75. title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
  76. title_np_img = np.array(title_pil_img)
  77. # 给title_pil_img添加上下左右各50像素白边padding
  78. title_np_img = cv2.copyMakeBorder(
  79. title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
  80. )
  81. title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
  82. ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
  83. if len(ocr_det_res) > 0:
  84. # 计算所有res的平均高度
  85. avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
  86. title_block['line_avg_height'] = round(avg_height/scale)
  87. else: # 有line信息,直接计算平均行高
  88. for title_block in title_blocks:
  89. lines = title_block.get('lines', [])
  90. if lines:
  91. # 使用列表推导式和内置函数,一次性计算平均高度
  92. avg_height = sum(line['bbox'][3] - line['bbox'][1] for line in lines) / len(lines)
  93. title_block['line_avg_height'] = round(avg_height)
  94. else:
  95. title_block['line_avg_height'] = title_block['bbox'][3] - title_block['bbox'][1]
  96. text_blocks = magic_model.get_text_blocks()
  97. interline_equation_blocks = magic_model.get_interline_equation_blocks()
  98. all_spans = magic_model.get_all_spans()
  99. # 对image/table/interline_equation的span截图
  100. for span in all_spans:
  101. if span["type"] in [ContentType.IMAGE, ContentType.TABLE, ContentType.INTERLINE_EQUATION]:
  102. span = cut_image_and_table(span, page_pil_img, page_img_md5, page_index, image_writer, scale=scale)
  103. page_blocks = []
  104. page_blocks.extend([
  105. *image_blocks,
  106. *table_blocks,
  107. *code_blocks,
  108. *ref_text_blocks,
  109. *phonetic_blocks,
  110. *title_blocks,
  111. *text_blocks,
  112. *interline_equation_blocks,
  113. *list_blocks,
  114. ])
  115. # 对page_blocks根据index的值进行排序
  116. page_blocks.sort(key=lambda x: x["index"])
  117. page_info = {"para_blocks": page_blocks, "discarded_blocks": discarded_blocks, "page_size": [width, height], "page_idx": page_index}
  118. return page_info
  119. def result_to_middle_json(
  120. model_output_blocks_list,
  121. inline_formula_list,
  122. ocr_res_list,
  123. images_list,
  124. pdf_doc,
  125. image_writer,
  126. _ocr_enable,
  127. _vlm_ocr_enable,
  128. hybrid_pipeline_model,
  129. ):
  130. middle_json = {
  131. "pdf_info": [],
  132. "_backend": "hybrid",
  133. "_ocr_enable": _ocr_enable,
  134. "_vlm_ocr_enable": _vlm_ocr_enable,
  135. "_version_name": __version__
  136. }
  137. for index, (page_blocks, page_inline_formula, page_ocr_res) in enumerate(zip(model_output_blocks_list, inline_formula_list, ocr_res_list)):
  138. page = pdf_doc[index]
  139. image_dict = images_list[index]
  140. page_info = blocks_to_page_info(
  141. page_blocks, page_inline_formula, page_ocr_res,
  142. image_dict, page, image_writer, index,
  143. _ocr_enable, _vlm_ocr_enable
  144. )
  145. middle_json["pdf_info"].append(page_info)
  146. if not (_vlm_ocr_enable or _ocr_enable):
  147. """后置ocr处理"""
  148. need_ocr_list = []
  149. img_crop_list = []
  150. text_block_list = []
  151. for page_info in middle_json["pdf_info"]:
  152. for block in page_info['para_blocks']:
  153. if block['type'] in ['table', 'image', 'list', 'code']:
  154. for sub_block in block['blocks']:
  155. if not sub_block['type'].endswith('body'):
  156. text_block_list.append(sub_block)
  157. elif block['type'] in ['text', 'title', 'ref_text']:
  158. text_block_list.append(block)
  159. for block in page_info['discarded_blocks']:
  160. text_block_list.append(block)
  161. for block in text_block_list:
  162. for line in block['lines']:
  163. for span in line['spans']:
  164. if 'np_img' in span:
  165. need_ocr_list.append(span)
  166. img_crop_list.append(span['np_img'])
  167. span.pop('np_img')
  168. if len(img_crop_list) > 0:
  169. ocr_res_list = hybrid_pipeline_model.ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
  170. assert len(ocr_res_list) == len(
  171. need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
  172. for index, span in enumerate(need_ocr_list):
  173. ocr_text, ocr_score = ocr_res_list[index]
  174. if ocr_score > OcrConfidence.min_confidence:
  175. span['content'] = ocr_text
  176. span['score'] = float(f"{ocr_score:.3f}")
  177. else:
  178. span['content'] = ''
  179. span['score'] = 0.0
  180. """表格跨页合并"""
  181. table_enable = get_table_enable(os.getenv('MINERU_VLM_TABLE_ENABLE', 'True').lower() == 'true')
  182. if table_enable:
  183. cross_page_table_merge(middle_json["pdf_info"])
  184. """llm优化标题分级"""
  185. if heading_level_import_success:
  186. llm_aided_title_start_time = time.time()
  187. llm_aided_title(middle_json["pdf_info"], title_aided_config)
  188. logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
  189. # 关闭pdf文档
  190. pdf_doc.close()
  191. return middle_json