vlm_magic_model.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. import re
  2. from typing import Literal
  3. from loguru import logger
  4. from mineru.utils.boxbase import calculate_overlap_area_in_bbox1_area_ratio
  5. from mineru.utils.enum_class import ContentType, BlockType
  6. from mineru.utils.guess_suffix_or_lang import guess_language_by_text
  7. from mineru.utils.magic_model_utils import reduct_overlap, tie_up_category_by_index
  8. class MagicModel:
  9. def __init__(self, page_blocks: list, width, height):
  10. self.page_blocks = page_blocks
  11. blocks = []
  12. self.all_spans = []
  13. # 解析每个块
  14. for index, block_info in enumerate(page_blocks):
  15. block_bbox = block_info["bbox"]
  16. try:
  17. x1, y1, x2, y2 = block_bbox
  18. x_1, y_1, x_2, y_2 = (
  19. int(x1 * width),
  20. int(y1 * height),
  21. int(x2 * width),
  22. int(y2 * height),
  23. )
  24. if x_2 < x_1:
  25. x_1, x_2 = x_2, x_1
  26. if y_2 < y_1:
  27. y_1, y_2 = y_2, y_1
  28. block_bbox = (x_1, y_1, x_2, y_2)
  29. block_type = block_info["type"]
  30. block_content = block_info["content"]
  31. block_angle = block_info["angle"]
  32. # print(f"坐标: {block_bbox}")
  33. # print(f"类型: {block_type}")
  34. # print(f"内容: {block_content}")
  35. # print("-" * 50)
  36. except Exception as e:
  37. # 如果解析失败,可能是因为格式不正确,跳过这个块
  38. logger.warning(f"Invalid block format: {block_info}, error: {e}")
  39. continue
  40. span_type = "unknown"
  41. code_block_sub_type = None
  42. guess_lang = None
  43. if block_type in [
  44. "text",
  45. "title",
  46. "image_caption",
  47. "image_footnote",
  48. "table_caption",
  49. "table_footnote",
  50. "code_caption",
  51. "ref_text",
  52. "phonetic",
  53. "header",
  54. "footer",
  55. "page_number",
  56. "aside_text",
  57. "page_footnote",
  58. "list"
  59. ]:
  60. span_type = ContentType.TEXT
  61. elif block_type in ["image"]:
  62. block_type = BlockType.IMAGE_BODY
  63. span_type = ContentType.IMAGE
  64. elif block_type in ["table"]:
  65. block_type = BlockType.TABLE_BODY
  66. span_type = ContentType.TABLE
  67. elif block_type in ["code", "algorithm"]:
  68. block_content = code_content_clean(block_content)
  69. code_block_sub_type = block_type
  70. block_type = BlockType.CODE_BODY
  71. span_type = ContentType.TEXT
  72. guess_lang = guess_language_by_text(block_content)
  73. elif block_type in ["equation"]:
  74. block_type = BlockType.INTERLINE_EQUATION
  75. span_type = ContentType.INTERLINE_EQUATION
  76. # code 和 algorithm 类型的块,如果内容中包含行内公式,则需要将块类型切换为algorithm
  77. switch_code_to_algorithm = False
  78. if span_type in ["image", "table"]:
  79. span = {
  80. "bbox": block_bbox,
  81. "type": span_type,
  82. }
  83. if span_type == ContentType.TABLE:
  84. span["html"] = block_content
  85. elif span_type in [ContentType.INTERLINE_EQUATION]:
  86. span = {
  87. "bbox": block_bbox,
  88. "type": span_type,
  89. "content": isolated_formula_clean(block_content),
  90. }
  91. else:
  92. if block_content:
  93. block_content = clean_content(block_content)
  94. if block_content and block_content.count("\\(") == block_content.count("\\)") and block_content.count("\\(") > 0:
  95. switch_code_to_algorithm = True
  96. # 生成包含文本和公式的span列表
  97. spans = []
  98. last_end = 0
  99. # 查找所有公式
  100. for match in re.finditer(r'\\\((.+?)\\\)', block_content):
  101. start, end = match.span()
  102. # 添加公式前的文本
  103. if start > last_end:
  104. text_before = block_content[last_end:start]
  105. if text_before.strip():
  106. spans.append({
  107. "bbox": block_bbox,
  108. "type": ContentType.TEXT,
  109. "content": text_before
  110. })
  111. # 添加公式(去除\(和\))
  112. formula = match.group(1)
  113. spans.append({
  114. "bbox": block_bbox,
  115. "type": ContentType.INLINE_EQUATION,
  116. "content": formula.strip()
  117. })
  118. last_end = end
  119. # 添加最后一个公式后的文本
  120. if last_end < len(block_content):
  121. text_after = block_content[last_end:]
  122. if text_after.strip():
  123. spans.append({
  124. "bbox": block_bbox,
  125. "type": ContentType.TEXT,
  126. "content": text_after
  127. })
  128. span = spans
  129. else:
  130. span = {
  131. "bbox": block_bbox,
  132. "type": span_type,
  133. "content": block_content,
  134. }
  135. # 处理span类型并添加到all_spans
  136. if isinstance(span, dict) and "bbox" in span:
  137. self.all_spans.append(span)
  138. spans = [span]
  139. elif isinstance(span, list):
  140. self.all_spans.extend(span)
  141. spans = span
  142. else:
  143. raise ValueError(f"Invalid span type: {span_type}, expected dict or list, got {type(span)}")
  144. # 构造line对象
  145. if block_type in [BlockType.CODE_BODY]:
  146. if switch_code_to_algorithm and code_block_sub_type == "code":
  147. code_block_sub_type = "algorithm"
  148. line = {"bbox": block_bbox, "spans": spans, "extra": {"type": code_block_sub_type, "guess_lang": guess_lang}}
  149. else:
  150. line = {"bbox": block_bbox, "spans": spans}
  151. blocks.append(
  152. {
  153. "bbox": block_bbox,
  154. "type": block_type,
  155. "angle": block_angle,
  156. "lines": [line],
  157. "index": index,
  158. }
  159. )
  160. self.image_blocks = []
  161. self.table_blocks = []
  162. self.interline_equation_blocks = []
  163. self.text_blocks = []
  164. self.title_blocks = []
  165. self.code_blocks = []
  166. self.discarded_blocks = []
  167. self.ref_text_blocks = []
  168. self.phonetic_blocks = []
  169. self.list_blocks = []
  170. for block in blocks:
  171. if block["type"] in [BlockType.IMAGE_BODY, BlockType.IMAGE_CAPTION, BlockType.IMAGE_FOOTNOTE]:
  172. self.image_blocks.append(block)
  173. elif block["type"] in [BlockType.TABLE_BODY, BlockType.TABLE_CAPTION, BlockType.TABLE_FOOTNOTE]:
  174. self.table_blocks.append(block)
  175. elif block["type"] in [BlockType.CODE_BODY, BlockType.CODE_CAPTION]:
  176. self.code_blocks.append(block)
  177. elif block["type"] == BlockType.INTERLINE_EQUATION:
  178. self.interline_equation_blocks.append(block)
  179. elif block["type"] == BlockType.TEXT:
  180. self.text_blocks.append(block)
  181. elif block["type"] == BlockType.TITLE:
  182. self.title_blocks.append(block)
  183. elif block["type"] in [BlockType.REF_TEXT]:
  184. self.ref_text_blocks.append(block)
  185. elif block["type"] in [BlockType.PHONETIC]:
  186. self.phonetic_blocks.append(block)
  187. elif block["type"] in [BlockType.HEADER, BlockType.FOOTER, BlockType.PAGE_NUMBER, BlockType.ASIDE_TEXT, BlockType.PAGE_FOOTNOTE]:
  188. self.discarded_blocks.append(block)
  189. elif block["type"] == BlockType.LIST:
  190. self.list_blocks.append(block)
  191. else:
  192. continue
  193. self.list_blocks, self.text_blocks, self.ref_text_blocks = fix_list_blocks(self.list_blocks, self.text_blocks, self.ref_text_blocks)
  194. self.image_blocks, not_include_image_blocks = fix_two_layer_blocks(self.image_blocks, BlockType.IMAGE)
  195. self.table_blocks, not_include_table_blocks = fix_two_layer_blocks(self.table_blocks, BlockType.TABLE)
  196. self.code_blocks, not_include_code_blocks = fix_two_layer_blocks(self.code_blocks, BlockType.CODE)
  197. for code_block in self.code_blocks:
  198. for block in code_block['blocks']:
  199. if block['type'] == BlockType.CODE_BODY:
  200. if len(block["lines"]) > 0:
  201. line = block["lines"][0]
  202. code_block["sub_type"] = line["extra"]["type"]
  203. if code_block["sub_type"] in ["code"]:
  204. code_block["guess_lang"] = line["extra"]["guess_lang"]
  205. del line["extra"]
  206. else:
  207. code_block["sub_type"] = "code"
  208. code_block["guess_lang"] = "txt"
  209. for block in not_include_image_blocks + not_include_table_blocks + not_include_code_blocks:
  210. block["type"] = BlockType.TEXT
  211. self.text_blocks.append(block)
  212. def get_list_blocks(self):
  213. return self.list_blocks
  214. def get_image_blocks(self):
  215. return self.image_blocks
  216. def get_table_blocks(self):
  217. return self.table_blocks
  218. def get_code_blocks(self):
  219. return self.code_blocks
  220. def get_ref_text_blocks(self):
  221. return self.ref_text_blocks
  222. def get_phonetic_blocks(self):
  223. return self.phonetic_blocks
  224. def get_title_blocks(self):
  225. return self.title_blocks
  226. def get_text_blocks(self):
  227. return self.text_blocks
  228. def get_interline_equation_blocks(self):
  229. return self.interline_equation_blocks
  230. def get_discarded_blocks(self):
  231. return self.discarded_blocks
  232. def get_all_spans(self):
  233. return self.all_spans
  234. def isolated_formula_clean(txt):
  235. latex = txt[:]
  236. if latex.startswith("\\["): latex = latex[2:]
  237. if latex.endswith("\\]"): latex = latex[:-2]
  238. latex = latex.strip()
  239. return latex
  240. def code_content_clean(content):
  241. """清理代码内容,移除Markdown代码块的开始和结束标记"""
  242. if not content:
  243. return ""
  244. lines = content.splitlines()
  245. start_idx = 0
  246. end_idx = len(lines)
  247. # 处理开头的三个反引号
  248. if lines and lines[0].startswith("```"):
  249. start_idx = 1
  250. # 处理结尾的三个反引号
  251. if lines and end_idx > start_idx and lines[end_idx - 1].strip() == "```":
  252. end_idx -= 1
  253. # 只有在有内容时才进行join操作
  254. if start_idx < end_idx:
  255. return "\n".join(lines[start_idx:end_idx]).strip()
  256. return ""
  257. def clean_content(content):
  258. if content and content.count("\\[") == content.count("\\]") and content.count("\\[") > 0:
  259. # Function to handle each match
  260. def replace_pattern(match):
  261. # Extract content between \[ and \]
  262. inner_content = match.group(1)
  263. return f"[{inner_content}]"
  264. # Find all patterns of \[x\] and apply replacement
  265. pattern = r'\\\[(.*?)\\\]'
  266. content = re.sub(pattern, replace_pattern, content)
  267. return content
  268. def __tie_up_category_by_index(blocks, subject_block_type, object_block_type):
  269. """基于index的主客体关联包装函数"""
  270. # 定义获取主体和客体对象的函数
  271. def get_subjects():
  272. return reduct_overlap(
  273. list(
  274. map(
  275. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle": x["angle"]},
  276. filter(
  277. lambda x: x["type"] == subject_block_type,
  278. blocks,
  279. ),
  280. )
  281. )
  282. )
  283. def get_objects():
  284. return reduct_overlap(
  285. list(
  286. map(
  287. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"], "angle": x["angle"]},
  288. filter(
  289. lambda x: x["type"] == object_block_type,
  290. blocks,
  291. ),
  292. )
  293. )
  294. )
  295. # 调用通用方法
  296. return tie_up_category_by_index(
  297. get_subjects,
  298. get_objects,
  299. object_block_type=object_block_type
  300. )
  301. def get_type_blocks(blocks, block_type: Literal["image", "table", "code"]):
  302. with_captions = __tie_up_category_by_index(blocks, f"{block_type}_body", f"{block_type}_caption")
  303. with_footnotes = __tie_up_category_by_index(blocks, f"{block_type}_body", f"{block_type}_footnote")
  304. ret = []
  305. for v in with_captions:
  306. record = {
  307. f"{block_type}_body": v["sub_bbox"],
  308. f"{block_type}_caption_list": v["obj_bboxes"],
  309. }
  310. filter_idx = v["sub_idx"]
  311. d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
  312. record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
  313. ret.append(record)
  314. return ret
  315. def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table", "code"]):
  316. need_fix_blocks = get_type_blocks(blocks, fix_type)
  317. fixed_blocks = []
  318. not_include_blocks = []
  319. processed_indices = set()
  320. # 特殊处理表格类型,确保标题在表格前,注脚在表格后
  321. if fix_type in ["table", "image"]:
  322. # 收集所有不合适的caption和footnote
  323. misplaced_captions = [] # 存储(caption, 原始block索引)
  324. misplaced_footnotes = [] # 存储(footnote, 原始block索引)
  325. # 第一步:移除不符合位置要求的footnote
  326. for block_idx, block in enumerate(need_fix_blocks):
  327. body = block[f"{fix_type}_body"]
  328. body_index = body["index"]
  329. # 检查footnote应在body后或同位置
  330. valid_footnotes = []
  331. for footnote in block[f"{fix_type}_footnote_list"]:
  332. if footnote["index"] >= body_index:
  333. valid_footnotes.append(footnote)
  334. else:
  335. misplaced_footnotes.append((footnote, block_idx))
  336. block[f"{fix_type}_footnote_list"] = valid_footnotes
  337. # 第三步:重新分配不合规的footnote到合适的body
  338. for footnote, original_block_idx in misplaced_footnotes:
  339. footnote_index = footnote["index"]
  340. best_block_idx = None
  341. min_distance = float('inf')
  342. # 寻找索引小于等于footnote_index的最近body
  343. for idx, block in enumerate(need_fix_blocks):
  344. body_index = block[f"{fix_type}_body"]["index"]
  345. if body_index <= footnote_index and idx != original_block_idx:
  346. distance = footnote_index - body_index
  347. if distance < min_distance:
  348. min_distance = distance
  349. best_block_idx = idx
  350. if best_block_idx is not None:
  351. # 找到合适的body,添加到对应block的footnote_list
  352. need_fix_blocks[best_block_idx][f"{fix_type}_footnote_list"].append(footnote)
  353. else:
  354. # 没找到合适的body,作为普通block处理
  355. not_include_blocks.append(footnote)
  356. # 第四步:将每个block的caption_list和footnote_list中不连续index的元素提出来作为普通block处理
  357. for block in need_fix_blocks:
  358. caption_list = block[f"{fix_type}_caption_list"]
  359. footnote_list = block[f"{fix_type}_footnote_list"]
  360. body_index = block[f"{fix_type}_body"]["index"]
  361. # 处理caption_list (从body往前看,caption在body之前)
  362. if caption_list:
  363. # 按index降序排列,从最接近body的开始检查
  364. caption_list.sort(key=lambda x: x["index"], reverse=True)
  365. filtered_captions = [caption_list[0]]
  366. for i in range(1, len(caption_list)):
  367. prev_index = caption_list[i - 1]["index"]
  368. curr_index = caption_list[i]["index"]
  369. # 检查是否连续
  370. if curr_index == prev_index - 1:
  371. filtered_captions.append(caption_list[i])
  372. else:
  373. # 检查gap中是否只有body_index
  374. gap_indices = set(range(curr_index + 1, prev_index))
  375. if gap_indices == {body_index}:
  376. # gap中只有body_index,不算真正的gap
  377. filtered_captions.append(caption_list[i])
  378. else:
  379. # 出现真正的gap,后续所有caption都作为普通block
  380. not_include_blocks.extend(caption_list[i:])
  381. break
  382. # 恢复升序
  383. filtered_captions.reverse()
  384. block[f"{fix_type}_caption_list"] = filtered_captions
  385. # 处理footnote_list (从body往后看,footnote在body之后)
  386. if footnote_list:
  387. # 按index升序排列,从最接近body的开始检查
  388. footnote_list.sort(key=lambda x: x["index"])
  389. filtered_footnotes = [footnote_list[0]]
  390. for i in range(1, len(footnote_list)):
  391. # 检查是否与前一个footnote连续
  392. if footnote_list[i]["index"] == footnote_list[i - 1]["index"] + 1:
  393. filtered_footnotes.append(footnote_list[i])
  394. else:
  395. # 出现gap,后续所有footnote都作为普通block
  396. not_include_blocks.extend(footnote_list[i:])
  397. break
  398. block[f"{fix_type}_footnote_list"] = filtered_footnotes
  399. # 构建两层结构blocks
  400. for block in need_fix_blocks:
  401. body = block[f"{fix_type}_body"]
  402. caption_list = block[f"{fix_type}_caption_list"]
  403. footnote_list = block[f"{fix_type}_footnote_list"]
  404. body["type"] = f"{fix_type}_body"
  405. for caption in caption_list:
  406. caption["type"] = f"{fix_type}_caption"
  407. processed_indices.add(caption["index"])
  408. for footnote in footnote_list:
  409. footnote["type"] = f"{fix_type}_footnote"
  410. processed_indices.add(footnote["index"])
  411. processed_indices.add(body["index"])
  412. two_layer_block = {
  413. "type": fix_type,
  414. "bbox": body["bbox"],
  415. "blocks": [body],
  416. "index": body["index"],
  417. }
  418. two_layer_block["blocks"].extend([*caption_list, *footnote_list])
  419. # 对blocks按index排序
  420. two_layer_block["blocks"].sort(key=lambda x: x["index"])
  421. fixed_blocks.append(two_layer_block)
  422. # 添加未处理的blocks
  423. for block in blocks:
  424. block.pop("type", None)
  425. if block["index"] not in processed_indices and block not in not_include_blocks:
  426. not_include_blocks.append(block)
  427. return fixed_blocks, not_include_blocks
  428. def fix_list_blocks(list_blocks, text_blocks, ref_text_blocks):
  429. for list_block in list_blocks:
  430. list_block["blocks"] = []
  431. if "lines" in list_block:
  432. del list_block["lines"]
  433. temp_text_blocks = text_blocks + ref_text_blocks
  434. need_remove_blocks = []
  435. for block in temp_text_blocks:
  436. for list_block in list_blocks:
  437. if calculate_overlap_area_in_bbox1_area_ratio(block["bbox"], list_block["bbox"]) >= 0.8:
  438. list_block["blocks"].append(block)
  439. need_remove_blocks.append(block)
  440. break
  441. for block in need_remove_blocks:
  442. if block in text_blocks:
  443. text_blocks.remove(block)
  444. elif block in ref_text_blocks:
  445. ref_text_blocks.remove(block)
  446. # 移除blocks为空的list_block
  447. list_blocks = [lb for lb in list_blocks if lb["blocks"]]
  448. for list_block in list_blocks:
  449. # 统计list_block["blocks"]中所有block的type,用众数作为list_block的sub_type
  450. type_count = {}
  451. for sub_block in list_block["blocks"]:
  452. sub_block_type = sub_block["type"]
  453. if sub_block_type not in type_count:
  454. type_count[sub_block_type] = 0
  455. type_count[sub_block_type] += 1
  456. if type_count:
  457. list_block["sub_type"] = max(type_count, key=type_count.get)
  458. else:
  459. list_block["sub_type"] = "unknown"
  460. return list_blocks, text_blocks, ref_text_blocks