span_pre_proc.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import collections
  3. import math
  4. import re
  5. import statistics
  6. import cv2
  7. import numpy as np
  8. from loguru import logger
  9. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio, calculate_iou, \
  10. get_minbox_if_overlap_by_ratio
  11. from mineru.utils.enum_class import BlockType, ContentType
  12. from mineru.utils.pdf_image_tools import get_crop_img
  13. from mineru.utils.pdf_text_tool import get_page
  14. def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
  15. def get_block_bboxes(blocks, block_type_list):
  16. return [block[0:4] for block in blocks if block[7] in block_type_list]
  17. image_bboxes = get_block_bboxes(all_bboxes, [BlockType.IMAGE_BODY])
  18. table_bboxes = get_block_bboxes(all_bboxes, [BlockType.TABLE_BODY])
  19. other_block_type = []
  20. for block_type in BlockType.__dict__.values():
  21. if not isinstance(block_type, str):
  22. continue
  23. if block_type not in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY]:
  24. other_block_type.append(block_type)
  25. other_block_bboxes = get_block_bboxes(all_bboxes, other_block_type)
  26. discarded_block_bboxes = get_block_bboxes(all_discarded_blocks, [BlockType.DISCARDED])
  27. new_spans = []
  28. for span in spans:
  29. span_bbox = span['bbox']
  30. span_type = span['type']
  31. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.4 for block_bbox in
  32. discarded_block_bboxes):
  33. new_spans.append(span)
  34. continue
  35. if span_type == ContentType.IMAGE:
  36. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  37. image_bboxes):
  38. new_spans.append(span)
  39. elif span_type == ContentType.TABLE:
  40. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  41. table_bboxes):
  42. new_spans.append(span)
  43. else:
  44. if any(calculate_overlap_area_in_bbox1_area_ratio(span_bbox, block_bbox) > 0.5 for block_bbox in
  45. other_block_bboxes):
  46. new_spans.append(span)
  47. return new_spans
  48. def remove_overlaps_low_confidence_spans(spans):
  49. dropped_spans = []
  50. # 删除重叠spans中置信度低的的那些
  51. for span1 in spans:
  52. for span2 in spans:
  53. if span1 != span2:
  54. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  55. if span1 in dropped_spans or span2 in dropped_spans:
  56. continue
  57. else:
  58. if calculate_iou(span1['bbox'], span2['bbox']) > 0.9:
  59. if span1['score'] < span2['score']:
  60. span_need_remove = span1
  61. else:
  62. span_need_remove = span2
  63. if (
  64. span_need_remove is not None
  65. and span_need_remove not in dropped_spans
  66. ):
  67. dropped_spans.append(span_need_remove)
  68. if len(dropped_spans) > 0:
  69. for span_need_remove in dropped_spans:
  70. spans.remove(span_need_remove)
  71. return spans, dropped_spans
  72. def remove_overlaps_min_spans(spans):
  73. dropped_spans = []
  74. # 删除重叠spans中较小的那些
  75. for span1 in spans:
  76. for span2 in spans:
  77. if span1 != span2:
  78. # span1 或 span2 任何一个都不应该在 dropped_spans 中
  79. if span1 in dropped_spans or span2 in dropped_spans:
  80. continue
  81. else:
  82. overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.65)
  83. if overlap_box is not None:
  84. span_need_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
  85. if span_need_remove is not None and span_need_remove not in dropped_spans:
  86. dropped_spans.append(span_need_remove)
  87. if len(dropped_spans) > 0:
  88. for span_need_remove in dropped_spans:
  89. spans.remove(span_need_remove)
  90. return spans, dropped_spans
  91. def __replace_ligatures(text: str):
  92. ligatures = {
  93. 'fi': 'fi', 'fl': 'fl', 'ff': 'ff', 'ffi': 'ffi', 'ffl': 'ffl', 'ſt': 'ft', 'st': 'st'
  94. }
  95. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  96. def __replace_unicode(text: str):
  97. ligatures = {
  98. '\r\n': '', '\u0002': '-',
  99. }
  100. return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
  101. """pdf_text dict方案 char级别"""
  102. def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
  103. page_dict = get_page(pdf_page)
  104. page_all_chars = []
  105. page_all_lines = []
  106. for block in page_dict['blocks']:
  107. for line in block['lines']:
  108. rotation_degrees = math.degrees(line['rotation'])
  109. # 旋转角度不为0, 90, 180, 270的行,直接跳过(rotation_degrees的值可能不为整数)
  110. if not any(abs(rotation_degrees - angle) < 0.1 for angle in [0, 90, 180, 270]):
  111. continue
  112. page_all_lines.append(line)
  113. for span in line['spans']:
  114. for char in span['chars']:
  115. page_all_chars.append(char)
  116. # 计算所有sapn的高度的中位数
  117. span_height_list = []
  118. for span in spans:
  119. if span['type'] in [ContentType.TEXT]:
  120. span_height = span['bbox'][3] - span['bbox'][1]
  121. span['height'] = span_height
  122. span['width'] = span['bbox'][2] - span['bbox'][0]
  123. span_height_list.append(span_height)
  124. if len(span_height_list) == 0:
  125. return spans
  126. else:
  127. median_span_height = statistics.median(span_height_list)
  128. useful_spans = []
  129. unuseful_spans = []
  130. # 纵向span的两个特征:1. 高度超过多个line 2. 高宽比超过某个值
  131. vertical_spans = []
  132. for span in spans:
  133. if span['type'] in [ContentType.TEXT]:
  134. for block in all_bboxes + all_discarded_blocks:
  135. if block[7] in [BlockType.IMAGE_BODY, BlockType.TABLE_BODY, BlockType.INTERLINE_EQUATION]:
  136. continue
  137. if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5:
  138. if span['height'] > median_span_height * 2.3 and span['height'] > span['width'] * 2.3:
  139. vertical_spans.append(span)
  140. elif block in all_bboxes:
  141. useful_spans.append(span)
  142. else:
  143. unuseful_spans.append(span)
  144. break
  145. """垂直的span框直接用line进行填充"""
  146. if len(vertical_spans) > 0:
  147. for pdfium_line in page_all_lines:
  148. for span in vertical_spans:
  149. if calculate_overlap_area_in_bbox1_area_ratio(pdfium_line['bbox'].bbox, span['bbox']) > 0.5:
  150. for pdfium_span in pdfium_line['spans']:
  151. span['content'] += pdfium_span['text']
  152. break
  153. for span in vertical_spans:
  154. if len(span['content']) == 0:
  155. spans.remove(span)
  156. """水平的span框先用char填充,再用ocr填充空的span框"""
  157. new_spans = []
  158. for span in useful_spans + unuseful_spans:
  159. if span['type'] in [ContentType.TEXT]:
  160. span['chars'] = []
  161. new_spans.append(span)
  162. need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars, median_span_height)
  163. """对未填充的span进行ocr"""
  164. if len(need_ocr_spans) > 0:
  165. for span in need_ocr_spans:
  166. # 对span的bbox截图再ocr
  167. span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
  168. span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
  169. # 计算span的对比度,低于0.20的span不进行ocr
  170. if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
  171. spans.remove(span)
  172. continue
  173. span['content'] = ''
  174. span['score'] = 1.0
  175. span['np_img'] = span_img
  176. return spans
  177. def fill_char_in_spans(spans, all_chars, median_span_height):
  178. # 简单从上到下排一下序
  179. spans = sorted(spans, key=lambda x: x['bbox'][1])
  180. grid_size = median_span_height
  181. grid = collections.defaultdict(list)
  182. for i, span in enumerate(spans):
  183. start_cell = int(span['bbox'][1] / grid_size)
  184. end_cell = int(span['bbox'][3] / grid_size)
  185. for cell_idx in range(start_cell, end_cell + 1):
  186. grid[cell_idx].append(i)
  187. for char in all_chars:
  188. char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2
  189. cell_idx = int(char_center_y / grid_size)
  190. candidate_span_indices = grid.get(cell_idx, [])
  191. for span_idx in candidate_span_indices:
  192. span = spans[span_idx]
  193. if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
  194. span['chars'].append(char)
  195. break
  196. need_ocr_spans = []
  197. for span in spans:
  198. chars_to_content(span)
  199. # 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
  200. if len(span['content']) * span['height'] < span['width'] * 0.5:
  201. # logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
  202. need_ocr_spans.append(span)
  203. del span['height'], span['width']
  204. return need_ocr_spans
  205. LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',', '-', '—', '–',)
  206. LINE_START_FLAG = ('(', '(', '"', '“', '【', '{', '《', '<', '「', '『', '【', '[',)
  207. Span_Height_Radio = 0.33 # 字符的中轴和span的中轴高度差不能超过1/3span高度
  208. def calculate_char_in_span(char_bbox, span_bbox, char, span_height_radio=Span_Height_Radio):
  209. char_center_x = (char_bbox[0] + char_bbox[2]) / 2
  210. char_center_y = (char_bbox[1] + char_bbox[3]) / 2
  211. span_center_y = (span_bbox[1] + span_bbox[3]) / 2
  212. span_height = span_bbox[3] - span_bbox[1]
  213. if (
  214. span_bbox[0] < char_center_x < span_bbox[2]
  215. and span_bbox[1] < char_center_y < span_bbox[3]
  216. and abs(char_center_y - span_center_y) < span_height * span_height_radio # 字符的中轴和span的中轴高度差不能超过Span_Height_Radio
  217. ):
  218. return True
  219. else:
  220. # 如果char是LINE_STOP_FLAG,就不用中心点判定,换一种方案(左边界在span区域内,高度判定和之前逻辑一致)
  221. # 主要是给结尾符号一个进入span的机会,这个char还应该离span右边界较近
  222. if char in LINE_STOP_FLAG:
  223. if (
  224. (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]
  225. and char_center_x > span_bbox[0]
  226. and span_bbox[1] < char_center_y < span_bbox[3]
  227. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  228. ):
  229. return True
  230. elif char in LINE_START_FLAG:
  231. if (
  232. span_bbox[0] < char_bbox[2] < (span_bbox[0] + span_height)
  233. and char_center_x < span_bbox[2]
  234. and span_bbox[1] < char_center_y < span_bbox[3]
  235. and abs(char_center_y - span_center_y) < span_height * span_height_radio
  236. ):
  237. return True
  238. else:
  239. return False
  240. def chars_to_content(span):
  241. # 检查span中的char是否为空
  242. if len(span['chars']) == 0:
  243. pass
  244. else:
  245. # 给chars按char_idx排序
  246. span['chars'] = sorted(span['chars'], key=lambda x: x['char_idx'])
  247. # Calculate the width of each character
  248. char_widths = [char['bbox'][2] - char['bbox'][0] for char in span['chars']]
  249. # Calculate the median width
  250. median_width = statistics.median(char_widths)
  251. content = ''
  252. for char in span['chars']:
  253. # 如果下一个char的x0和上一个char的x1距离超过0.25个字符宽度,则需要在中间插入一个空格
  254. char1 = char
  255. char2 = span['chars'][span['chars'].index(char) + 1] if span['chars'].index(char) + 1 < len(span['chars']) else None
  256. if char2 and char2['bbox'][0] - char1['bbox'][2] > median_width * 0.25 and char['char'] != ' ' and char2['char'] != ' ':
  257. content += f"{char['char']} "
  258. else:
  259. content += char['char']
  260. content = __replace_unicode(content)
  261. content = __replace_ligatures(content)
  262. content = __replace_ligatures(content)
  263. span['content'] = content.strip()
  264. del span['chars']
  265. def calculate_contrast(img, img_mode) -> float:
  266. """
  267. 计算给定图像的对比度。
  268. :param img: 图像,类型为numpy.ndarray
  269. :Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
  270. :return: 图像的对比度值
  271. """
  272. if img_mode == 'rgb':
  273. # 将RGB图像转换为灰度图
  274. gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  275. elif img_mode == 'bgr':
  276. # 将BGR图像转换为灰度图
  277. gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  278. else:
  279. raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")
  280. # 计算均值和标准差
  281. mean_value = np.mean(gray_img)
  282. std_dev = np.std(gray_img)
  283. # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
  284. contrast = std_dev / (mean_value + 1e-6)
  285. # logger.debug(f"contrast: {contrast}")
  286. return round(contrast, 2)