hybrid_magic_model.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. import re
  2. from typing import Literal
  3. from loguru import logger
  4. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  5. from mineru.utils.enum_class import ContentType, BlockType, NotExtractType
  6. from mineru.utils.guess_suffix_or_lang import guess_language_by_text
  7. from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_index
  8. from mineru.utils.span_block_fix import fix_text_block
  9. from mineru.utils.span_pre_proc import txt_spans_extract
  10. not_extract_list = [item.value for item in NotExtractType]
class MagicModel:
    """Turn one page's raw layout blocks into typed blocks grouped by category.

    The constructor scales every bbox with ``cal_real_bbox``, builds spans and
    lines for each block, sorts the blocks into per-type lists and finally
    nests list/image/table/code blocks. The results are read back through the
    ``get_*`` accessors.
    """

    def __init__(self,
        page_blocks: list,
        page_inline_formula,
        page_ocr_res,
        page,
        scale,
        page_pil_img,
        width,
        height,
        _ocr_enable,
        _vlm_ocr_enable,
    ):
        """Build all block lists for a single page.

        Args:
            page_blocks: Block dicts; each is read for "bbox", "type",
                "content" and "angle". Blocks missing any of these are logged
                and skipped.
            page_inline_formula: Inline-formula dicts ("bbox", optional
                "latex", "score"). Only read when ``_vlm_ocr_enable`` is
                false; mutated in place (bbox rewritten, "latex" popped).
            page_ocr_res: OCR result dicts ("bbox", "category_id", plus
                "text"/"score" for category 15). Only read when
                ``_vlm_ocr_enable`` is false; bbox is rewritten in place.
            page, scale, page_pil_img: Forwarded untouched to
                ``txt_spans_extract`` when both OCR flags are false.
            width, height: Multipliers applied to bbox coordinates by
                ``cal_real_bbox``; also the extent of the page-sized virtual
                block handed to ``txt_spans_extract``.
            _ocr_enable: When false (and ``_vlm_ocr_enable`` is false too) the
                collected spans are replaced by ``txt_spans_extract`` output.
            _vlm_ocr_enable: When true, every textual block uses its own
                content and the span-filling path is never taken.

        Raises:
            ValueError: If a block's span is neither a dict with a "bbox" key
                nor a list.
        """
        self.page_blocks = page_blocks
        self.page_inline_formula = page_inline_formula
        self.page_ocr_res = page_ocr_res

        self.width = width
        self.height = height

        blocks = []
        self.all_spans = []

        # Page-level pool of text / inline-formula spans; consumed later by the
        # span-filling path, which removes each span once a block claims it.
        page_text_inline_formula_spans = []
        if not _vlm_ocr_enable:
            for inline_formula in page_inline_formula:
                inline_formula["bbox"] = self.cal_real_bbox(inline_formula["bbox"])
                inline_formula_latex = inline_formula.pop("latex", "")
                # Formulas without LaTeX text are not turned into spans.
                if inline_formula_latex:
                    page_text_inline_formula_spans.append({
                        "bbox": inline_formula["bbox"],
                        "type": ContentType.INLINE_EQUATION,
                        "content": inline_formula_latex,
                        "score": inline_formula["score"],
                    })
            for ocr_res in page_ocr_res:
                ocr_res["bbox"] = self.cal_real_bbox(ocr_res["bbox"])
                # NOTE(review): category_id 15 presumably marks OCR text lines - confirm.
                if ocr_res['category_id'] == 15:
                    page_text_inline_formula_spans.append({
                        "bbox": ocr_res["bbox"],
                        "type": ContentType.TEXT,
                        "content": ocr_res["text"],
                        "score": ocr_res["score"],
                    })
            if not _ocr_enable:
                # Hand txt_spans_extract a single page-sized "text" block and
                # use whatever span list it returns from here on.
                virtual_block = [0, 0, width, height, None, None, None, "text"]
                page_text_inline_formula_spans = txt_spans_extract(page, page_text_inline_formula_spans, page_pil_img, scale, [virtual_block],[])

        # Parse every block.
        for index, block_info in enumerate(page_blocks):
            try:
                block_bbox = self.cal_real_bbox(block_info["bbox"])
                block_type = block_info["type"]
                block_content = block_info["content"]
                block_angle = block_info["angle"]
            except Exception as e:
                # Parsing failed, most likely because the block is malformed; skip it.
                logger.warning(f"Invalid block format: {block_info}, error: {e}")
                continue

            span_type = "unknown"
            code_block_sub_type = None
            guess_lang = None

            # Map the incoming block type onto a span type (and, for
            # image/table/code/equation, onto the matching BlockType).
            if block_type in [
                "text",
                "title",
                "image_caption",
                "image_footnote",
                "table_caption",
                "table_footnote",
                "code_caption",
                "ref_text",
                "phonetic",
                "header",
                "footer",
                "page_number",
                "aside_text",
                "page_footnote",
                "list"
            ]:
                span_type = ContentType.TEXT
            elif block_type in ["image"]:
                block_type = BlockType.IMAGE_BODY
                span_type = ContentType.IMAGE
            elif block_type in ["table"]:
                block_type = BlockType.TABLE_BODY
                span_type = ContentType.TABLE
            elif block_type in ["code", "algorithm"]:
                block_content = code_content_clean(block_content)
                # Remember whether the source said "code" or "algorithm".
                code_block_sub_type = block_type
                block_type = BlockType.CODE_BODY
                span_type = ContentType.TEXT
                guess_lang = guess_language_by_text(block_content)
            elif block_type in ["equation"]:
                block_type = BlockType.INTERLINE_EQUATION
                span_type = ContentType.INTERLINE_EQUATION

            # For code/algorithm blocks: if the content contains inline
            # formulas, the sub-type has to be switched to "algorithm".
            switch_code_to_algorithm = False

            span = None
            if span_type in ["image", "table"]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                }
                if span_type == ContentType.TABLE:
                    span["html"] = block_content
            elif span_type in [ContentType.INTERLINE_EQUATION]:
                span = {
                    "bbox": block_bbox,
                    "type": span_type,
                    "content": isolated_formula_clean(block_content),
                }
            elif _vlm_ocr_enable or block_type not in not_extract_list:
                # With _vlm_ocr_enable every text block uses the block's own content.
                # Without it, block types listed in not_extract_list are filled
                # from spans instead (the else branch further below).
                if block_content:
                    block_content = clean_content(block_content)

                if block_content and block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:

                    switch_code_to_algorithm = True

                    # Build a span list that interleaves text and formulas.
                    spans = []
                    last_end = 0

                    # Find every inline formula.
                    for match in re.finditer(r'\\\((.+?)\\\)', block_content):
                        start, end = match.span()

                        # Text in front of the formula.
                        if start > last_end:
                            text_before = block_content[last_end:start]
                            if text_before.strip():
                                spans.append({
                                    "bbox": block_bbox,
                                    "type": ContentType.TEXT,
                                    "content": text_before
                                })

                        # The formula itself, without the \( and \) delimiters.
                        formula = match.group(1)
                        spans.append({
                            "bbox": block_bbox,
                            "type": ContentType.INLINE_EQUATION,
                            "content": formula.strip()
                        })

                        last_end = end

                    # Text after the last formula.
                    if last_end < len(block_content):
                        text_after = block_content[last_end:]
                        if text_after.strip():
                            spans.append({
                                "bbox": block_bbox,
                                "type": ContentType.TEXT,
                                "content": text_after
                            })

                    span = spans
                else:
                    span = {
                        "bbox": block_bbox,
                        "type": span_type,
                        "content": block_content,
                    }

            if (
                span_type in ["image", "table", ContentType.INTERLINE_EQUATION]
                or (_vlm_ocr_enable or block_type not in not_extract_list)
            ):
                if span is None:
                    continue
                # Normalise span to a list and record it in all_spans.
                if isinstance(span, dict) and "bbox" in span:
                    self.all_spans.append(span)
                    spans = [span]
                elif isinstance(span, list):
                    self.all_spans.extend(span)
                    spans = span
                else:
                    raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")

                # Build the line object.
                if block_type in [BlockType.CODE_BODY]:
                    if switch_code_to_algorithm and code_block_sub_type == "code":
                        code_block_sub_type = "algorithm"
                    # "extra" is read back (and deleted) after the code blocks are nested below.
                    line = {"bbox": block_bbox, "spans": spans,
                            "extra": {"type": code_block_sub_type, "guess_lang": guess_lang}}
                else:
                    line = {"bbox": block_bbox, "spans": spans}

                block = {
                    "bbox": block_bbox,
                    "type": block_type,
                    "angle": block_angle,
                    "lines": [line],
                    "index": index,
                }

            else:  # Span-filling path.
                block_spans = []
                for span in page_text_inline_formula_spans:
                    if calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block_bbox) > 0.5:
                        block_spans.append(span)
                # Remove the spans just claimed by this block from the page pool.
                if len(block_spans) > 0:
                    for span in block_spans:
                        page_text_inline_formula_spans.remove(span)

                block = {
                    "bbox": block_bbox,
                    "type": block_type,
                    "angle": block_angle,
                    "spans": block_spans,
                    "index": index,
                }
                block = fix_text_block(block)

            blocks.append(block)

        self.image_blocks = []
        self.table_blocks = []
        self.interline_equation_blocks = []
        self.text_blocks = []
        self.title_blocks = []
        self.code_blocks = []
        self.discarded_blocks = []
        self.ref_text_blocks = []
        self.phonetic_blocks = []
        self.list_blocks = []
        # Sort blocks into per-category lists; unrecognised types are dropped.
        for block in blocks:
            if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
                self.image_blocks.append(block)
            elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
                self.table_blocks.append(block)
            elif block["type"] in [BlockType.CODE_BODY, BlockType.CODE_CAPTION]:
                self.code_blocks.append(block)
            elif block["type"] == BlockType.INTERLINE_EQUATION:
                self.interline_equation_blocks.append(block)
            elif block["type"] == BlockType.TEXT:
                self.text_blocks.append(block)
            elif block["type"] == BlockType.TITLE:
                self.title_blocks.append(block)
            elif block["type"] in [BlockType.REF_TEXT]:
                self.ref_text_blocks.append(block)
            elif block["type"] in [BlockType.PHONETIC]:
                self.phonetic_blocks.append(block)
            elif block["type"] in [BlockType.HEADER, BlockType.FOOTER, BlockType.PAGE_NUMBER, BlockType.ASIDE_TEXT, BlockType.PAGE_FOOTNOTE]:
                self.discarded_blocks.append(block)
            elif block["type"] == BlockType.LIST:
                self.list_blocks.append(block)
            else:
                continue

        # Nest text/ref_text blocks into list blocks, then group image, table
        # and code bodies with their captions and footnotes.
        self.list_blocks, self.text_blocks, self.ref_text_blocks = fix_list_blocks(self.list_blocks, self.text_blocks, self.ref_text_blocks)
        self.image_blocks, not_include_image_blocks = fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
        self.table_blocks, not_include_table_blocks = fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
        self.code_blocks, not_include_code_blocks = fix_two_layer_blocks(self.code_blocks, BlockType.CODE)
        # Lift sub_type / guess_lang from the code body's first line up to the
        # enclosing code block and drop the temporary "extra" entry.
        for code_block in self.code_blocks:
            for block in code_block['blocks']:
                if block['type'] == BlockType.CODE_BODY:
                    if len(block["lines"]) > 0:
                        line = block["lines"][0]
                        code_block["sub_type"] = line["extra"]["type"]
                        # guess_lang is only copied for plain "code", not for "algorithm".
                        if code_block["sub_type"] in ["code"]:
                            code_block["guess_lang"] = line["extra"]["guess_lang"]
                        del line["extra"]
                    else:
                        code_block["sub_type"] = "code"
                        code_block["guess_lang"] = "txt"

        # Blocks that could not be attached to any body are kept as plain text.
        for block in not_include_image_blocks + not_include_table_blocks + not_include_code_blocks:
            block["type"] = BlockType.TEXT
            self.text_blocks.append(block)

    def cal_real_bbox(self, bbox):
        """Scale a bbox by the page size and return it as an ordered int tuple.

        x values are multiplied by ``self.width`` and y values by
        ``self.height``, then truncated with ``int``. Coordinates are swapped
        where needed so that x1 <= x2 and y1 <= y2.
        NOTE(review): this assumes incoming coordinates are fractions of the
        page size - confirm against the caller.
        """
        x1, y1, x2, y2 = bbox
        x_1, y_1, x_2, y_2 = (
            int(x1 * self.width),
            int(y1 * self.height),
            int(x2 * self.width),
            int(y2 * self.height),
        )
        if x_2 < x_1:
            x_1, x_2 = x_2, x_1
        if y_2 < y_1:
            y_1, y_2 = y_2, y_1
        bbox = (x_1, y_1, x_2, y_2)
        return bbox

    def get_list_blocks(self):
        """Return the list blocks that contain at least one nested block."""
        return self.list_blocks

    def get_image_blocks(self):
        """Return the two-layer image blocks."""
        return self.image_blocks

    def get_table_blocks(self):
        """Return the two-layer table blocks."""
        return self.table_blocks

    def get_code_blocks(self):
        """Return the two-layer code blocks."""
        return self.code_blocks

    def get_ref_text_blocks(self):
        """Return the ref_text blocks that were not nested into a list block."""
        return self.ref_text_blocks

    def get_phonetic_blocks(self):
        """Return the phonetic blocks."""
        return self.phonetic_blocks

    def get_title_blocks(self):
        """Return the title blocks."""
        return self.title_blocks

    def get_text_blocks(self):
        """Return the text blocks, including blocks demoted to text."""
        return self.text_blocks

    def get_interline_equation_blocks(self):
        """Return the interline equation blocks."""
        return self.interline_equation_blocks

    def get_discarded_blocks(self):
        """Return header/footer/page-number/aside/page-footnote blocks."""
        return self.discarded_blocks

    def get_all_spans(self):
        """Return every span built from block content (span-filled blocks excluded)."""
        return self.all_spans
  304. def isolated_formula_clean(txt):
  305. latex = txt[:]
  306. if latex.startswith("\\["): latex = latex[2:]
  307. if latex.endswith("\\]"): latex = latex[:-2]
  308. latex = latex.strip()
  309. return latex
  310. def code_content_clean(content):
  311. """清理代码内容,移除Markdown代码块的开始和结束标记"""
  312. if not content:
  313. return ""
  314. lines = content.splitlines()
  315. start_idx = 0
  316. end_idx = len(lines)
  317. # 处理开头的三个反引号
  318. if lines and lines[0].startswith("```"):
  319. start_idx = 1
  320. # 处理结尾的三个反引号
  321. if lines and end_idx > start_idx and lines[end_idx - 1].strip() == "```":
  322. end_idx -= 1
  323. # 只有在有内容时才进行join操作
  324. if start_idx < end_idx:
  325. return "\n".join(lines[start_idx:end_idx]).strip()
  326. return ""
  327. def clean_content(content):
  328. if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
  329. # Function to handle each match
  330. def replace_pattern(match):
  331. # Extract content between \[ and \]
  332. inner_content = match.group(1)
  333. return f"[{inner_content}]"
  334. # Find all patterns of \[x\] and apply replacement
  335. pattern = r'\\\[(.*?)\\\]'
  336. content = re.sub(pattern, replace_pattern, content)
  337. return content
  338. def __tie_up_category_by_index(blocks, subject_block_type, object_block_type):
  339. """基于index的主客体关联包装函数"""
  340. # 定义获取主体和客体对象的函数
  341. def get_subjects():
  342. return reduct_overlap(
  343. list(
  344. map(
  345. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle": x["angle"]},
  346. filter(
  347. lambda x: x["type"] == subject_block_type,
  348. blocks,
  349. ),
  350. )
  351. )
  352. )
  353. def get_objects():
  354. return reduct_overlap(
  355. list(
  356. map(
  357. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle": x["angle"]},
  358. filter(
  359. lambda x: x["type"] == object_block_type,
  360. blocks,
  361. ),
  362. )
  363. )
  364. )
  365. # 调用通用方法
  366. return tie_up_category_by_index(
  367. get_subjects,
  368. get_objects,
  369. object_block_type=object_block_type
  370. )
  371. def get_type_blocks_by_index(blocks, block_type: Literal["image", "table", "code"]):
  372. """使用基于index的匹配策略来组织blocks"""
  373. with_captions = __tie_up_category_by_index(blocks, f"{block_type}_body", f"{block_type}_caption")
  374. with_footnotes = __tie_up_category_by_index(blocks, f"{block_type}_body", f"{block_type}_footnote")
  375. ret = []
  376. for v in with_captions:
  377. record = {
  378. f"{block_type}_body": v["sub_bbox"],
  379. f"{block_type}_caption_list": v["obj_bboxes"],
  380. }
  381. filter_idx = v["sub_idx"]
  382. d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
  383. record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
  384. ret.append(record)
  385. return ret
def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table", "code"]):
    """Group bodies with their captions and footnotes into two-layer blocks.

    Args:
        blocks: Flat blocks of one family (body, caption, footnote). The
            "type" key is removed from every one of them as a side effect.
        fix_type: Family name; used as the "type" of each resulting block and
            as the prefix of the ``*_body`` / ``*_caption`` / ``*_footnote``
            keys and types.

    Returns:
        ``(fixed_blocks, not_include_blocks)``. ``fixed_blocks`` holds one
        dict per body with "type", "bbox", "index" and "blocks" (body,
        captions and footnotes sorted by index). ``not_include_blocks`` holds
        captions/footnotes that were detached plus any input block that did
        not end up in a group; these carry no "type" key.
    """
    need_fix_blocks = get_type_blocks_by_index(blocks, fix_type)
    fixed_blocks = []
    not_include_blocks = []
    processed_indices = set()

    # Tables and images get extra handling so that a footnote never precedes its body.
    if fix_type in ["table", "image"]:
        # Footnotes found in front of their body, stored as
        # (footnote, position of the group they came from).
        misplaced_footnotes = []

        # Step 1: take out footnotes that violate the position requirement.
        for block_idx, block in enumerate(need_fix_blocks):
            body = block[f"{fix_type}_body"]
            body_index = body["index"]

            # A footnote has to sit at or after its body.
            valid_footnotes = []
            for footnote in block[f"{fix_type}_footnote_list"]:
                if footnote["index"] >= body_index:
                    valid_footnotes.append(footnote)
                else:
                    misplaced_footnotes.append((footnote, block_idx))
            block[f"{fix_type}_footnote_list"] = valid_footnotes

        # Step 2: re-assign each misplaced footnote to a suitable body.
        for footnote, original_block_idx in misplaced_footnotes:
            footnote_index = footnote["index"]
            best_block_idx = None
            min_distance = float('inf')

            # Look for the nearest body whose index is <= the footnote's index.
            for idx, block in enumerate(need_fix_blocks):
                body_index = block[f"{fix_type}_body"]["index"]
                if body_index <= footnote_index and idx != original_block_idx:
                    distance = footnote_index - body_index
                    if distance < min_distance:
                        min_distance = distance
                        best_block_idx = idx

            if best_block_idx is not None:
                # Found a suitable body: attach the footnote to that group.
                need_fix_blocks[best_block_idx][f"{fix_type}_footnote_list"].append(footnote)
            else:
                # No suitable body: treat the footnote as an ordinary block.
                not_include_blocks.append(footnote)

    # Step 3: in every group, detach captions/footnotes whose indices are not
    # contiguous and treat them as ordinary blocks.
    for block in need_fix_blocks:
        caption_list = block[f"{fix_type}_caption_list"]
        footnote_list = block[f"{fix_type}_footnote_list"]
        body_index = block[f"{fix_type}_body"]["index"]

        # Captions: walk from the highest index downwards.
        if caption_list:
            # Sort by index, descending; the first caption is always kept.
            caption_list.sort(key=lambda x: x["index"], reverse=True)
            filtered_captions = [caption_list[0]]
            for i in range(1, len(caption_list)):
                prev_index = caption_list[i - 1]["index"]
                curr_index = caption_list[i]["index"]

                # Directly adjacent to the previous caption?
                if curr_index == prev_index - 1:
                    filtered_captions.append(caption_list[i])
                else:
                    # Does the gap consist of nothing but the body itself?
                    gap_indices = set(range(curr_index + 1, prev_index))
                    if gap_indices == {body_index}:
                        # Only the body lies in between, so this is not a real gap.
                        filtered_captions.append(caption_list[i])
                    else:
                        # A real gap: this and all remaining captions become ordinary blocks.
                        not_include_blocks.extend(caption_list[i:])
                        break
            # Restore ascending order.
            filtered_captions.reverse()
            block[f"{fix_type}_caption_list"] = filtered_captions

        # Footnotes: walk from the lowest index upwards.
        if footnote_list:
            # Sort by index, ascending; the first footnote is always kept.
            footnote_list.sort(key=lambda x: x["index"])
            filtered_footnotes = [footnote_list[0]]
            for i in range(1, len(footnote_list)):
                # Directly adjacent to the previous footnote?
                if footnote_list[i]["index"] == footnote_list[i - 1]["index"] + 1:
                    filtered_footnotes.append(footnote_list[i])
                else:
                    # A gap: this and all remaining footnotes become ordinary blocks.
                    not_include_blocks.extend(footnote_list[i:])
                    break
            block[f"{fix_type}_footnote_list"] = filtered_footnotes

    # Build the two-layer blocks.
    for block in need_fix_blocks:
        body = block[f"{fix_type}_body"]
        caption_list = block[f"{fix_type}_caption_list"]
        footnote_list = block[f"{fix_type}_footnote_list"]

        body["type"] = f"{fix_type}_body"
        for caption in caption_list:
            caption["type"] = f"{fix_type}_caption"
            processed_indices.add(caption["index"])
        for footnote in footnote_list:
            footnote["type"] = f"{fix_type}_footnote"
            processed_indices.add(footnote["index"])

        processed_indices.add(body["index"])

        two_layer_block = {
            "type": fix_type,
            "bbox": body["bbox"],
            "blocks": [body],
            "index": body["index"],
        }
        two_layer_block["blocks"].extend([*caption_list, *footnote_list])
        # Order the nested blocks by index.
        two_layer_block["blocks"].sort(key=lambda x: x["index"])

        fixed_blocks.append(two_layer_block)

    # Add the blocks that were not placed in any group.
    for block in blocks:
        # "type" is removed before the membership test below, which compares
        # dicts by value. NOTE(review): presumably this lets an input block
        # match the type-less copy already detached above - confirm.
        block.pop("type", None)
        if block["index"] not in processed_indices and block not in not_include_blocks:
            not_include_blocks.append(block)

    return fixed_blocks, not_include_blocks
  498. def fix_list_blocks(list_blocks, text_blocks, ref_text_blocks):
  499. for list_block in list_blocks:
  500. list_block["blocks"] = []
  501. if "lines" in list_block:
  502. del list_block["lines"]
  503. temp_text_blocks = text_blocks + ref_text_blocks
  504. need_remove_blocks = []
  505. for block in temp_text_blocks:
  506. for list_block in list_blocks:
  507. if calculate_overlap_area_in_bbox1_area_ratio(block["bbox"], list_block["bbox"]) >= 0.8:
  508. list_block["blocks"].append(block)
  509. need_remove_blocks.append(block)
  510. break
  511. for block in need_remove_blocks:
  512. if block in text_blocks:
  513. text_blocks.remove(block)
  514. elif block in ref_text_blocks:
  515. ref_text_blocks.remove(block)
  516. # 移除blocks为空的list_block
  517. list_blocks = [lb for lb in list_blocks if lb["blocks"]]
  518. for list_block in list_blocks:
  519. # 统计list_block["blocks"]中所有block的type,用众数作为list_block的sub_type
  520. type_count = {}
  521. for sub_block in list_block["blocks"]:
  522. sub_block_type = sub_block["type"]
  523. if sub_block_type not in type_count:
  524. type_count[sub_block_type] = 0
  525. type_count[sub_block_type] += 1
  526. if type_count:
  527. list_block["sub_type"] = max(type_count, key=type_count.get)
  528. else:
  529. list_block["sub_type"] = "unknown"
  530. return list_blocks, text_blocks, ref_text_blocks