hybrid_analyze.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. import time
  4. from collections import defaultdict
  5. import cv2
  6. import numpy as np
  7. from loguru import logger
  8. from mineru_vl_utils import MinerUClient
  9. from mineru_vl_utils.structs import BlockType
  10. from tqdm import tqdm
  11. from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json
  12. from mineru.backend.pipeline.model_init import HybridModelSingleton
  13. from mineru.backend.vlm.vlm_analyze import ModelSingleton
  14. from mineru.data.data_reader_writer import DataWriter
  15. from mineru.utils.config_reader import get_device
  16. from mineru.utils.enum_class import ImageType, NotExtractType
  17. from mineru.utils.model_utils import crop_img, get_vram, clean_memory
  18. from mineru.utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, sorted_boxes, merge_det_boxes, \
  19. update_det_boxes, OcrConfidence
  20. from mineru.utils.pdf_classify import classify
  21. from mineru.utils.pdf_image_tools import load_images_from_pdf
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU for unsupported ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check

# Base batch sizes; multiplied by the runtime batch ratio (see get_batch_ratio).
MFR_BASE_BATCH_SIZE = 16
OCR_DET_BASE_BATCH_SIZE = 16

# Block types the VLM leaves unextracted; these are routed to the OCR pipeline.
not_extract_list = [item.value for item in NotExtractType]
  27. def ocr_classify(pdf_bytes, parse_method: str = 'auto',) -> bool:
  28. # 确定OCR设置
  29. _ocr_enable = False
  30. if parse_method == 'auto':
  31. if classify(pdf_bytes) == 'ocr':
  32. _ocr_enable = True
  33. elif parse_method == 'ocr':
  34. _ocr_enable = True
  35. return _ocr_enable
def ocr_det(
    hybrid_pipeline_model,
    np_images,
    results,
    mfd_res,
    _ocr_enable,
    batch_radio: int = 1,
):
    """Run OCR text detection over the not-extracted blocks of each page.

    Args:
        hybrid_pipeline_model: pipeline bundle exposing ``ocr_model``,
            ``lang`` and the ``enable_ocr_det_batch`` switch.
        np_images: per-page RGB numpy images.
        results: per-page VLM block lists; only blocks whose ``type`` is in
            ``not_extract_list`` are cropped and sent to detection.
        mfd_res: per-page formula boxes, used to exclude formula regions
            from text detection.
        _ocr_enable: forwarded to ``get_ocr_result_list``; when true, crops
            are kept for a later recognition pass.
        batch_radio: multiplier for ``OCR_DET_BASE_BATCH_SIZE`` in batch mode.

    Returns:
        One list of OCR detection results per input page.
    """
    ocr_res_list = []
    if not hybrid_pipeline_model.enable_ocr_det_batch:
        # Non-batch mode - process one page at a time.
        for np_image, page_mfd_res, page_results in tqdm(
            zip(np_images, mfd_res, results),
            total=len(np_images),
            desc="OCR-det"
        ):
            ocr_res_list.append([])
            img_height, img_width = np_image.shape[:2]
            for res in page_results:
                # Only blocks the VLM skipped go through OCR detection.
                if res['type'] not in not_extract_list:
                    continue
                # Normalized bbox -> pixel coordinates, clamped to the page.
                x0 = max(0, int(res['bbox'][0] * img_width))
                y0 = max(0, int(res['bbox'][1] * img_height))
                x1 = min(img_width, int(res['bbox'][2] * img_width))
                y1 = min(img_height, int(res['bbox'][3] * img_height))
                if x1 <= x0 or y1 <= y0:
                    continue
                res['poly'] = [x0, y0, x1, y0, x1, y1, x0, y1]
                new_image, useful_list = crop_img(
                    res, np_image, crop_paste_x=50, crop_paste_y=50
                )
                # Shift page-level formula boxes into the crop's coordinates.
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    page_mfd_res, useful_list
                )
                # The OCR model expects BGR input.
                bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                ocr_res = hybrid_pipeline_model.ocr_model.ocr(
                    bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
                )[0]
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(
                        ocr_res, useful_list, _ocr_enable, bgr_image, hybrid_pipeline_model.lang
                    )
                    ocr_res_list[-1].extend(ocr_result_list)
    else:
        # Batch mode - group crops by language and resolution.
        # First collect every crop that needs OCR detection.
        all_cropped_images_info = []
        for np_image, page_mfd_res, page_results in zip(
            np_images, mfd_res, results
        ):
            ocr_res_list.append([])
            img_height, img_width = np_image.shape[:2]
            for res in page_results:
                if res['type'] not in not_extract_list:
                    continue
                x0 = max(0, int(res['bbox'][0] * img_width))
                y0 = max(0, int(res['bbox'][1] * img_height))
                x1 = min(img_width, int(res['bbox'][2] * img_width))
                y1 = min(img_height, int(res['bbox'][3] * img_height))
                if x1 <= x0 or y1 <= y0:
                    continue
                res['poly'] = [x0, y0, x1, y0, x1, y1, x0, y1]
                new_image, useful_list = crop_img(
                    res, np_image, crop_paste_x=50, crop_paste_y=50
                )
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    page_mfd_res, useful_list
                )
                bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
                # Keep a reference to this page's result list so batch output
                # can be routed back to the right page later.
                all_cropped_images_info.append((
                    bgr_image, useful_list, adjusted_mfdetrec_res, ocr_res_list[-1]
                ))
        # Group by target (padded) resolution; padding target doubles as key.
        RESOLUTION_GROUP_STRIDE = 64  # 32
        resolution_groups = defaultdict(list)
        for crop_info in all_cropped_images_info:
            cropped_img = crop_info[0]
            h, w = cropped_img.shape[:2]
            # Round dimensions up to the stride and use them as the group key.
            target_h = ((h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
            target_w = ((w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
            group_key = (target_h, target_w)
            resolution_groups[group_key].append(crop_info)
        # Run batched detection once per resolution group.
        for (target_h, target_w), group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det"):
            # Pad every crop in the group to the shared target size.
            batch_images = []
            for crop_info in group_crops:
                img = crop_info[0]
                h, w = img.shape[:2]
                # White canvas of the target size; crop pasted at the top-left.
                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
                padded_img[:h, :w] = img
                batch_images.append(padded_img)
            # Batched detection.
            det_batch_size = min(len(batch_images), batch_radio*OCR_DET_BASE_BATCH_SIZE)
            batch_results = hybrid_pipeline_model.ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
            # Post-process each crop's detection boxes.
            for crop_info, (dt_boxes, _) in zip(group_crops, batch_results):
                bgr_image, useful_list, adjusted_mfdetrec_res, ocr_page_res_list = crop_info
                if dt_boxes is not None and len(dt_boxes) > 0:
                    # Sort reading-order, then merge adjacent boxes.
                    dt_boxes_sorted = sorted_boxes(dt_boxes)
                    dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
                    # Adjust text boxes around detected formula regions.
                    dt_boxes_final = (update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
                                      if dt_boxes_merged and adjusted_mfdetrec_res
                                      else dt_boxes_merged)
                    if dt_boxes_final:
                        ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
                        ocr_result_list = get_ocr_result_list(
                            ocr_res, useful_list, _ocr_enable, bgr_image, hybrid_pipeline_model.lang
                        )
                        ocr_page_res_list.extend(ocr_result_list)
    return ocr_res_list
  151. def mask_image_regions(np_images, results):
  152. # 根据vlm返回的结果,在每一页中将image、table、equation块mask成白色背景图像
  153. for np_image, vlm_page_results in zip(np_images, results):
  154. img_height, img_width = np_image.shape[:2]
  155. # 收集需要mask的区域
  156. mask_regions = []
  157. for block in vlm_page_results:
  158. if block['type'] in [BlockType.IMAGE, BlockType.TABLE, BlockType.EQUATION]:
  159. bbox = block['bbox']
  160. # 批量转换归一化坐标到像素坐标,并进行边界检查
  161. x0 = max(0, int(bbox[0] * img_width))
  162. y0 = max(0, int(bbox[1] * img_height))
  163. x1 = min(img_width, int(bbox[2] * img_width))
  164. y1 = min(img_height, int(bbox[3] * img_height))
  165. # 只添加有效区域
  166. if x1 > x0 and y1 > y0:
  167. mask_regions.append((y0, y1, x0, x1))
  168. # 批量应用mask
  169. for y0, y1, x0, x1 in mask_regions:
  170. np_image[y0:y1, x0:x1, :] = 255
  171. return np_images
  172. def normalize_poly_to_bbox(item, page_width, page_height):
  173. """将poly坐标归一化为bbox"""
  174. poly = item['poly']
  175. x0 = min(max(poly[0] / page_width, 0), 1)
  176. y0 = min(max(poly[1] / page_height, 0), 1)
  177. x1 = min(max(poly[4] / page_width, 0), 1)
  178. y1 = min(max(poly[5] / page_height, 0), 1)
  179. item['bbox'] = [round(x0, 3), round(y0, 3), round(x1, 3), round(y1, 3)]
  180. item.pop('poly', None)
def _process_ocr_and_formulas(
    images_pil_list,
    results,
    language,
    inline_formula_enable,
    _ocr_enable,
    batch_radio: int = 1,
):
    """Run inline-formula detection/recognition and OCR over text blocks.

    Walks ``results``, crops text blocks and hands them to OCR.
    ``_ocr_enable`` decides whether OCR runs detection only or det + rec;
    ``inline_formula_enable`` decides between the combined MFD+OCR flow and
    a pure OCR flow.

    Returns:
        Tuple of (inline_formula_list, ocr_res_list, hybrid_pipeline_model).
    """
    # Convert PIL pages to writable numpy arrays (copy: masking mutates them).
    np_images = [np.asarray(pil_image).copy() for pil_image in images_pil_list]
    # Obtain the shared hybrid pipeline model instance.
    hybrid_model_singleton = HybridModelSingleton()
    hybrid_pipeline_model = hybrid_model_singleton.get_model(
        lang=language,
        formula_enable=inline_formula_enable,
    )
    if inline_formula_enable:
        # Before inline-formula detection/recognition, mask image, table and
        # interline-formula regions out of the page images.
        np_images = mask_image_regions(np_images, results)
        # Formula detection.
        images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, batch_size=1, conf=0.5)
        # Formula recognition.
        inline_formula_list = hybrid_pipeline_model.mfr_model.batch_predict(
            images_mfd_res,
            np_images,
            batch_size=batch_radio*MFR_BASE_BATCH_SIZE,
            interline_enable=True,
        )
    else:
        inline_formula_list = [[] for _ in range(len(images_pil_list))]
    # Reshape formula results into per-page bbox lists used to mask formulas
    # during OCR detection.
    mfd_res = []
    for page_inline_formula_list in inline_formula_list:
        page_mfd_res = []
        for formula in page_inline_formula_list:
            # 13 appears to be the inline-formula category id — TODO confirm
            # against the pipeline's category table.
            formula['category_id'] = 13
            page_mfd_res.append({
                "bbox": [int(formula['poly'][0]), int(formula['poly'][1]),
                         int(formula['poly'][4]), int(formula['poly'][5])],
            })
        mfd_res.append(page_mfd_res)
    # The VLM did not perform OCR, so run OCR detection here.
    ocr_res_list = ocr_det(
        hybrid_pipeline_model,
        np_images,
        results,
        mfd_res,
        _ocr_enable,
        batch_radio=batch_radio,
    )
    # When OCR is required, follow up with recognition (ocr_rec).
    if _ocr_enable:
        need_ocr_list = []
        img_crop_list = []
        for page_ocr_res_list in ocr_res_list:
            for ocr_res in page_ocr_res_list:
                # 'np_img' marks entries whose crop was kept for recognition.
                if 'np_img' in ocr_res:
                    need_ocr_list.append(ocr_res)
                    img_crop_list.append(ocr_res.pop('np_img'))
        if len(img_crop_list) > 0:
            # Process OCR
            ocr_result_list = hybrid_pipeline_model.ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
            # Verify we have matching counts
            assert len(ocr_result_list) == len(need_ocr_list), f'ocr_result_list: {len(ocr_result_list)}, need_ocr_list: {len(need_ocr_list)}'
            # Process OCR results for this language
            for index, need_ocr_res in enumerate(need_ocr_list):
                ocr_text, ocr_score = ocr_result_list[index]
                need_ocr_res['text'] = ocr_text
                need_ocr_res['score'] = float(f"{ocr_score:.3f}")
                if ocr_score < OcrConfidence.min_confidence:
                    # Low-confidence text: demote to category 16 (presumably
                    # "abandoned text" — confirm against the category table).
                    need_ocr_res['category_id'] = 16
                else:
                    layout_res_bbox = [need_ocr_res['poly'][0], need_ocr_res['poly'][1],
                                       need_ocr_res['poly'][4], need_ocr_res['poly'][5]]
                    layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
                    layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
                    # Heuristic: known junk strings inside tall, narrow,
                    # low-score boxes are treated as garbage recognition.
                    if (
                        ocr_text in [
                            '(204号', '(20', '(2', '(2号', '(20号', '号','(204',
                            '(cid:)', '(ci:)', '(cd:1)', 'cd:)', 'c)', '(cd:)', 'c', 'id:)',
                            ':)', '√:)', '√i:)', '−i:)', '−:' , 'i:)',
                        ]
                        and ocr_score < 0.8
                        and layout_res_width < layout_res_height
                    ):
                        need_ocr_res['category_id'] = 16
    return inline_formula_list, ocr_res_list, hybrid_pipeline_model
  271. def _normalize_bbox(
  272. inline_formula_list,
  273. ocr_res_list,
  274. images_pil_list,
  275. ):
  276. """归一化坐标并生成最终结果"""
  277. for page_inline_formula_list, page_ocr_res_list, page_pil_image in zip(
  278. inline_formula_list, ocr_res_list, images_pil_list
  279. ):
  280. if page_inline_formula_list or page_ocr_res_list:
  281. page_width, page_height = page_pil_image.size
  282. # 处理公式列表
  283. for formula in page_inline_formula_list:
  284. normalize_poly_to_bbox(formula, page_width, page_height)
  285. # 处理OCR结果列表
  286. for ocr_res in page_ocr_res_list:
  287. normalize_poly_to_bbox(ocr_res, page_width, page_height)
  288. def get_batch_ratio(device):
  289. """
  290. 根据显存大小或环境变量获取 batch ratio
  291. """
  292. # 1. 优先尝试从环境变量获取
  293. """
  294. c/s架构分离部署时,建议通过设置环境变量 MINERU_HYBRID_BATCH_RATIO 来指定 batch ratio
  295. 建议的设置值如如下,以下配置值已考虑一定的冗余,单卡多终端部署时为了保证稳定性,可以额外保留一个client端的显存作为整体冗余
  296. 单个client端显存大小 | MINERU_HYBRID_BATCH_RATIO
  297. ------------------|------------------------
  298. <= 6 GB | 8
  299. <= 4.5 GB | 4
  300. <= 3 GB | 2
  301. <= 2.5 GB | 1
  302. 例如:
  303. export MINERU_HYBRID_BATCH_RATIO=4
  304. """
  305. env_val = os.getenv("MINERU_HYBRID_BATCH_RATIO")
  306. if env_val:
  307. try:
  308. batch_ratio = int(env_val)
  309. logger.info(f"hybrid batch ratio (from env): {batch_ratio}")
  310. return batch_ratio
  311. except ValueError as e:
  312. logger.warning(f"Invalid MINERU_HYBRID_BATCH_RATIO value: {env_val}, switching to auto mode. Error: {e}")
  313. # 2. 根据显存自动推断
  314. """
  315. 根据总显存大小粗略估计 batch ratio,需要排除掉vllm等推理框架占用的显存开销
  316. """
  317. gpu_memory = get_vram(device)
  318. if gpu_memory >= 32:
  319. batch_ratio = 16
  320. elif gpu_memory >= 16:
  321. batch_ratio = 8
  322. elif gpu_memory >= 12:
  323. batch_ratio = 4
  324. elif gpu_memory >= 8:
  325. batch_ratio = 2
  326. else:
  327. batch_ratio = 1
  328. logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
  329. return batch_ratio
  330. def _should_enable_vlm_ocr(ocr_enable: bool, language: str, inline_formula_enable: bool) -> bool:
  331. """判断是否启用VLM OCR"""
  332. force_enable = os.getenv("MINERU_FORCE_VLM_OCR_ENABLE", "0").lower() in ("1", "true", "yes")
  333. if force_enable:
  334. return True
  335. force_pipeline = os.getenv("MINERU_HYBRID_FORCE_PIPELINE_ENABLE", "0").lower() in ("1", "true", "yes")
  336. return (
  337. ocr_enable
  338. and language in ["ch", "en"]
  339. and inline_formula_enable
  340. and not force_pipeline
  341. )
  342. def doc_analyze(
  343. pdf_bytes,
  344. image_writer: DataWriter | None,
  345. predictor: MinerUClient | None = None,
  346. backend="transformers",
  347. parse_method: str = 'auto',
  348. language: str = 'ch',
  349. inline_formula_enable: bool = True,
  350. model_path: str | None = None,
  351. server_url: str | None = None,
  352. **kwargs,
  353. ):
  354. # 初始化预测器
  355. if predictor is None:
  356. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  357. # 加载图像
  358. load_images_start = time.time()
  359. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
  360. images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
  361. load_images_time = round(time.time() - load_images_start, 2)
  362. logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
  363. # 获取设备信息
  364. device = get_device()
  365. # 确定OCR配置
  366. _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
  367. _vlm_ocr_enable = _should_enable_vlm_ocr(_ocr_enable, language, inline_formula_enable)
  368. infer_start = time.time()
  369. # VLM提取
  370. if _vlm_ocr_enable:
  371. results = predictor.batch_two_step_extract(images=images_pil_list)
  372. hybrid_pipeline_model = None
  373. inline_formula_list = [[] for _ in images_pil_list]
  374. ocr_res_list = [[] for _ in images_pil_list]
  375. else:
  376. batch_ratio = get_batch_ratio(device)
  377. results = predictor.batch_two_step_extract(
  378. images=images_pil_list,
  379. not_extract_list=not_extract_list
  380. )
  381. inline_formula_list, ocr_res_list, hybrid_pipeline_model = _process_ocr_and_formulas(
  382. images_pil_list,
  383. results,
  384. language,
  385. inline_formula_enable,
  386. _ocr_enable,
  387. batch_radio=batch_ratio,
  388. )
  389. _normalize_bbox(inline_formula_list, ocr_res_list, images_pil_list)
  390. infer_time = round(time.time() - infer_start, 2)
  391. logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  392. # 生成中间JSON
  393. middle_json = result_to_middle_json(
  394. results,
  395. inline_formula_list,
  396. ocr_res_list,
  397. images_list,
  398. pdf_doc,
  399. image_writer,
  400. _ocr_enable,
  401. _vlm_ocr_enable,
  402. hybrid_pipeline_model,
  403. )
  404. clean_memory(device)
  405. return middle_json, results, _vlm_ocr_enable
  406. async def aio_doc_analyze(
  407. pdf_bytes,
  408. image_writer: DataWriter | None,
  409. predictor: MinerUClient | None = None,
  410. backend="transformers",
  411. parse_method: str = 'auto',
  412. language: str = 'ch',
  413. inline_formula_enable: bool = True,
  414. model_path: str | None = None,
  415. server_url: str | None = None,
  416. **kwargs,
  417. ):
  418. # 初始化预测器
  419. if predictor is None:
  420. predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
  421. # 加载图像
  422. load_images_start = time.time()
  423. images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
  424. images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
  425. load_images_time = round(time.time() - load_images_start, 2)
  426. logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
  427. # 获取设备信息
  428. device = get_device()
  429. # 确定OCR配置
  430. _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
  431. _vlm_ocr_enable = _should_enable_vlm_ocr(_ocr_enable, language, inline_formula_enable)
  432. infer_start = time.time()
  433. # VLM提取
  434. if _vlm_ocr_enable:
  435. results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
  436. hybrid_pipeline_model = None
  437. inline_formula_list = [[] for _ in images_pil_list]
  438. ocr_res_list = [[] for _ in images_pil_list]
  439. else:
  440. batch_ratio = get_batch_ratio(device)
  441. results = await predictor.aio_batch_two_step_extract(
  442. images=images_pil_list,
  443. not_extract_list=not_extract_list
  444. )
  445. inline_formula_list, ocr_res_list, hybrid_pipeline_model = _process_ocr_and_formulas(
  446. images_pil_list,
  447. results,
  448. language,
  449. inline_formula_enable,
  450. _ocr_enable,
  451. batch_radio=batch_ratio,
  452. )
  453. _normalize_bbox(inline_formula_list, ocr_res_list, images_pil_list)
  454. infer_time = round(time.time() - infer_start, 2)
  455. logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
  456. # 生成中间JSON
  457. middle_json = result_to_middle_json(
  458. results,
  459. inline_formula_list,
  460. ocr_res_list,
  461. images_list,
  462. pdf_doc,
  463. image_writer,
  464. _ocr_enable,
  465. _vlm_ocr_enable,
  466. hybrid_pipeline_model,
  467. )
  468. clean_memory(device)
  469. return middle_json, results, _vlm_ocr_enable