model_init.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import os
  2. import torch
  3. from loguru import logger
  4. from .model_list import AtomicModel
  5. from ...model.layout.doclayoutyolo import DocLayoutYOLOModel
  6. from ...model.mfd.yolo_v8 import YOLOv8MFDModel
  7. from ...model.mfr.unimernet.Unimernet import UnimernetModel
  8. from ...model.mfr.pp_formulanet_plus_m.predict_formula import FormulaRecognizer
  9. from mineru.model.ocr.pytorch_paddle import PytorchPaddleOCR
  10. from ...model.ori_cls.paddle_ori_cls import PaddleOrientationClsModel
  11. from ...model.table.cls.paddle_table_cls import PaddleTableClsModel
  12. # from ...model.table.rec.RapidTable import RapidTableModel
  13. from ...model.table.rec.slanet_plus.main import RapidTableModel
  14. from ...model.table.rec.unet_table.main import UnetTableModel
  15. from ...utils.config_reader import get_device
  16. from ...utils.enum_class import ModelPath
  17. from ...utils.models_download_utils import auto_download_and_get_model_root_path
  18. MFR_MODEL = os.getenv('MINERU_FORMULA_CH_SUPPORT', 'False')
  19. if MFR_MODEL.lower() in ['true', '1', 'yes']:
  20. MFR_MODEL = "pp_formulanet_plus_m"
  21. elif MFR_MODEL.lower() in ['false', '0', 'no']:
  22. MFR_MODEL = "unimernet_small"
  23. else:
  24. logger.warning(f"Invalid MINERU_FORMULA_CH_SUPPORT value: {MFR_MODEL}, set to default 'False'")
  25. MFR_MODEL = "unimernet_small"
  26. def img_orientation_cls_model_init():
  27. atom_model_manager = AtomModelSingleton()
  28. ocr_engine = atom_model_manager.get_atom_model(
  29. atom_model_name=AtomicModel.OCR,
  30. det_db_box_thresh=0.5,
  31. det_db_unclip_ratio=1.6,
  32. lang="ch_lite",
  33. enable_merge_det_boxes=False
  34. )
  35. cls_model = PaddleOrientationClsModel(ocr_engine)
  36. return cls_model
  37. def table_cls_model_init():
  38. return PaddleTableClsModel()
  39. def wired_table_model_init(lang=None):
  40. atom_model_manager = AtomModelSingleton()
  41. ocr_engine = atom_model_manager.get_atom_model(
  42. atom_model_name=AtomicModel.OCR,
  43. det_db_box_thresh=0.5,
  44. det_db_unclip_ratio=1.6,
  45. lang=lang,
  46. enable_merge_det_boxes=False
  47. )
  48. table_model = UnetTableModel(ocr_engine)
  49. return table_model
  50. def wireless_table_model_init(lang=None):
  51. atom_model_manager = AtomModelSingleton()
  52. ocr_engine = atom_model_manager.get_atom_model(
  53. atom_model_name=AtomicModel.OCR,
  54. det_db_box_thresh=0.5,
  55. det_db_unclip_ratio=1.6,
  56. lang=lang,
  57. enable_merge_det_boxes=False
  58. )
  59. table_model = RapidTableModel(ocr_engine)
  60. return table_model
  61. def mfd_model_init(weight, device='cpu'):
  62. if str(device).startswith('npu'):
  63. device = torch.device(device)
  64. mfd_model = YOLOv8MFDModel(weight, device)
  65. return mfd_model
  66. def mfr_model_init(weight_dir, device='cpu'):
  67. if MFR_MODEL == "unimernet_small":
  68. mfr_model = UnimernetModel(weight_dir, device)
  69. elif MFR_MODEL == "pp_formulanet_plus_m":
  70. mfr_model = FormulaRecognizer(weight_dir, device)
  71. else:
  72. logger.error('MFR model name not allow')
  73. exit(1)
  74. return mfr_model
  75. def doclayout_yolo_model_init(weight, device='cpu'):
  76. if str(device).startswith('npu'):
  77. device = torch.device(device)
  78. model = DocLayoutYOLOModel(weight, device)
  79. return model
  80. def ocr_model_init(det_db_box_thresh=0.3,
  81. lang=None,
  82. det_db_unclip_ratio=1.8,
  83. enable_merge_det_boxes=True
  84. ):
  85. if lang is not None and lang != '':
  86. model = PytorchPaddleOCR(
  87. det_db_box_thresh=det_db_box_thresh,
  88. lang=lang,
  89. use_dilation=True,
  90. det_db_unclip_ratio=det_db_unclip_ratio,
  91. enable_merge_det_boxes=enable_merge_det_boxes,
  92. )
  93. else:
  94. model = PytorchPaddleOCR(
  95. det_db_box_thresh=det_db_box_thresh,
  96. use_dilation=True,
  97. det_db_unclip_ratio=det_db_unclip_ratio,
  98. enable_merge_det_boxes=enable_merge_det_boxes,
  99. )
  100. return model
  101. class AtomModelSingleton:
  102. _instance = None
  103. _models = {}
  104. def __new__(cls, *args, **kwargs):
  105. if cls._instance is None:
  106. cls._instance = super().__new__(cls)
  107. return cls._instance
  108. def get_atom_model(self, atom_model_name: str, **kwargs):
  109. lang = kwargs.get('lang', None)
  110. if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
  111. key = (
  112. atom_model_name,
  113. lang
  114. )
  115. elif atom_model_name in [AtomicModel.OCR]:
  116. key = (
  117. atom_model_name,
  118. kwargs.get('det_db_box_thresh', 0.3),
  119. lang,
  120. kwargs.get('det_db_unclip_ratio', 1.8),
  121. kwargs.get('enable_merge_det_boxes', True)
  122. )
  123. else:
  124. key = atom_model_name
  125. if key not in self._models:
  126. self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
  127. return self._models[key]
  128. def atom_model_init(model_name: str, **kwargs):
  129. atom_model = None
  130. if model_name == AtomicModel.Layout:
  131. atom_model = doclayout_yolo_model_init(
  132. kwargs.get('doclayout_yolo_weights'),
  133. kwargs.get('device')
  134. )
  135. elif model_name == AtomicModel.MFD:
  136. atom_model = mfd_model_init(
  137. kwargs.get('mfd_weights'),
  138. kwargs.get('device')
  139. )
  140. elif model_name == AtomicModel.MFR:
  141. atom_model = mfr_model_init(
  142. kwargs.get('mfr_weight_dir'),
  143. kwargs.get('device')
  144. )
  145. elif model_name == AtomicModel.OCR:
  146. atom_model = ocr_model_init(
  147. kwargs.get('det_db_box_thresh', 0.3),
  148. kwargs.get('lang'),
  149. kwargs.get('det_db_unclip_ratio', 1.8),
  150. kwargs.get('enable_merge_det_boxes', True)
  151. )
  152. elif model_name == AtomicModel.WirelessTable:
  153. atom_model = wireless_table_model_init(
  154. kwargs.get('lang'),
  155. )
  156. elif model_name == AtomicModel.WiredTable:
  157. atom_model = wired_table_model_init(
  158. kwargs.get('lang'),
  159. )
  160. elif model_name == AtomicModel.TableCls:
  161. atom_model = table_cls_model_init()
  162. elif model_name == AtomicModel.ImgOrientationCls:
  163. atom_model = img_orientation_cls_model_init()
  164. else:
  165. logger.error('model name not allow')
  166. exit(1)
  167. if atom_model is None:
  168. logger.error('model init failed')
  169. exit(1)
  170. else:
  171. return atom_model
  172. class MineruPipelineModel:
  173. def __init__(self, **kwargs):
  174. self.formula_config = kwargs.get('formula_config')
  175. self.apply_formula = self.formula_config.get('enable', True)
  176. self.table_config = kwargs.get('table_config')
  177. self.apply_table = self.table_config.get('enable', True)
  178. self.lang = kwargs.get('lang', None)
  179. self.device = kwargs.get('device', 'cpu')
  180. logger.info(
  181. 'DocAnalysis init, this may take some times......'
  182. )
  183. atom_model_manager = AtomModelSingleton()
  184. if self.apply_formula:
  185. # 初始化公式检测模型
  186. self.mfd_model = atom_model_manager.get_atom_model(
  187. atom_model_name=AtomicModel.MFD,
  188. mfd_weights=str(
  189. os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
  190. ),
  191. device=self.device,
  192. )
  193. # 初始化公式解析模型
  194. if MFR_MODEL == "unimernet_small":
  195. mfr_model_path = ModelPath.unimernet_small
  196. elif MFR_MODEL == "pp_formulanet_plus_m":
  197. mfr_model_path = ModelPath.pp_formulanet_plus_m
  198. else:
  199. logger.error('MFR model name not allow')
  200. exit(1)
  201. self.mfr_model = atom_model_manager.get_atom_model(
  202. atom_model_name=AtomicModel.MFR,
  203. mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
  204. device=self.device,
  205. )
  206. # 初始化layout模型
  207. self.layout_model = atom_model_manager.get_atom_model(
  208. atom_model_name=AtomicModel.Layout,
  209. doclayout_yolo_weights=str(
  210. os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
  211. ),
  212. device=self.device,
  213. )
  214. # 初始化ocr
  215. self.ocr_model = atom_model_manager.get_atom_model(
  216. atom_model_name=AtomicModel.OCR,
  217. det_db_box_thresh=0.3,
  218. lang=self.lang
  219. )
  220. # init table model
  221. if self.apply_table:
  222. self.wired_table_model = atom_model_manager.get_atom_model(
  223. atom_model_name=AtomicModel.WiredTable,
  224. lang=self.lang,
  225. )
  226. self.wireless_table_model = atom_model_manager.get_atom_model(
  227. atom_model_name=AtomicModel.WirelessTable,
  228. lang=self.lang,
  229. )
  230. self.table_cls_model = atom_model_manager.get_atom_model(
  231. atom_model_name=AtomicModel.TableCls,
  232. )
  233. self.img_orientation_cls_model = atom_model_manager.get_atom_model(
  234. atom_model_name=AtomicModel.ImgOrientationCls,
  235. lang=self.lang,
  236. )
  237. logger.info('DocAnalysis init done!')
  238. class HybridModelSingleton:
  239. _instance = None
  240. _models = {}
  241. def __new__(cls, *args, **kwargs):
  242. if cls._instance is None:
  243. cls._instance = super().__new__(cls)
  244. return cls._instance
  245. def get_model(
  246. self,
  247. lang=None,
  248. formula_enable=None,
  249. ):
  250. key = (lang, formula_enable)
  251. if key not in self._models:
  252. self._models[key] = MineruHybridModel(
  253. lang=lang,
  254. formula_enable=formula_enable,
  255. )
  256. return self._models[key]
  257. def ocr_det_batch_setting(device):
  258. # 检测torch的版本号
  259. import torch
  260. from packaging import version
  261. if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
  262. enable_ocr_det_batch = False
  263. else:
  264. enable_ocr_det_batch = True
  265. return enable_ocr_det_batch
  266. class MineruHybridModel:
  267. def __init__(
  268. self,
  269. device=None,
  270. lang=None,
  271. formula_enable=True,
  272. ):
  273. if device is not None:
  274. self.device = device
  275. else:
  276. self.device = get_device()
  277. self.lang = lang
  278. self.enable_ocr_det_batch = ocr_det_batch_setting(self.device)
  279. if str(self.device).startswith('npu'):
  280. try:
  281. import torch_npu
  282. if torch_npu.npu.is_available():
  283. torch_npu.npu.set_compile_mode(jit_compile=False)
  284. except Exception as e:
  285. raise RuntimeError(
  286. "NPU is selected as device, but torch_npu is not available. "
  287. "Please ensure that the torch_npu package is installed correctly."
  288. ) from e
  289. self.atom_model_manager = AtomModelSingleton()
  290. # 初始化OCR模型
  291. self.ocr_model = self.atom_model_manager.get_atom_model(
  292. atom_model_name=AtomicModel.OCR,
  293. det_db_box_thresh=0.3,
  294. lang=self.lang
  295. )
  296. if formula_enable:
  297. # 初始化公式检测模型
  298. self.mfd_model = self.atom_model_manager.get_atom_model(
  299. atom_model_name=AtomicModel.MFD,
  300. mfd_weights=str(
  301. os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
  302. ),
  303. device=self.device,
  304. )
  305. # 初始化公式解析模型
  306. if MFR_MODEL == "unimernet_small":
  307. mfr_model_path = ModelPath.unimernet_small
  308. elif MFR_MODEL == "pp_formulanet_plus_m":
  309. mfr_model_path = ModelPath.pp_formulanet_plus_m
  310. else:
  311. logger.error('MFR model name not allow')
  312. exit(1)
  313. self.mfr_model = self.atom_model_manager.get_atom_model(
  314. atom_model_name=AtomicModel.MFR,
  315. mfr_weight_dir=str(os.path.join(auto_download_and_get_model_root_path(mfr_model_path), mfr_model_path)),
  316. device=self.device,
  317. )