# Copyright (c) Opendatalab. All rights reserved.
import asyncio
import os
import time
import json

from loguru import logger

from .utils import (
    enable_custom_logits_processors,
    set_default_gpu_memory_utilization,
    set_default_batch_size,
    set_lmdeploy_backend,
)
from .model_output_to_middle_json import result_to_middle_json
from ...data.data_reader_writer import DataWriter
from mineru.utils.pdf_image_tools import load_images_from_pdf
from ...utils.check_sys_env import is_mac_os_version_supported
from ...utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.models_download_utils import auto_download_and_get_model_root_path
from mineru_vl_utils import MinerUClient
from packaging import version


class ModelSingleton:
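    """Process-wide cache of MinerUClient instances.

    `__new__` makes the class a singleton, and `get_model` memoizes one
    client per (backend, model_path, server_url) tuple so repeated calls
    do not re-initialize the underlying model or inference engine.
    """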
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance
    def get_model(
        self,
        backend: str,
        model_path: str | None,
        server_url: str | None,
        **kwargs,
    ) -> MinerUClient:
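        """Return a cached MinerUClient for the given backend, creating it on first use.

        Supported backends: "transformers", "mlx-engine", "vllm-engine",
        "vllm-async-engine", "lmdeploy-engine", and "http-client".
        Client-level options (batch_size, max_concurrency, http_timeout,
        server_headers, max_retries, retry_backoff_factor) are consumed here;
        all remaining kwargs are forwarded to the selected engine's constructor.
        """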
        key = (backend, model_path, server_url)
        if key not in self._models:
            start_time = time.time()
            model = None
            processor = None
            vllm_llm = None
            lmdeploy_engine = None
            vllm_async_llm = None
            batch_size = kwargs.get("batch_size", 0)  # for the transformers backend only
            max_concurrency = kwargs.get("max_concurrency", 100)  # for the http-client backend only
            http_timeout = kwargs.get("http_timeout", 600)  # for the http-client backend only
            server_headers = kwargs.get("server_headers", None)  # for the http-client backend only
            max_retries = kwargs.get("max_retries", 3)  # for the http-client backend only
            retry_backoff_factor = kwargs.get("retry_backoff_factor", 0.5)  # for the http-client backend only
            # Remove these parameters from kwargs so they are not forwarded to engine initializers that do not accept them.
            for param in ["batch_size", "max_concurrency", "http_timeout", "server_headers", "max_retries", "retry_backoff_factor"]:
                if param in kwargs:
                    del kwargs[param]
            if backend not in ["http-client"] and not model_path:
                model_path = auto_download_and_get_model_root_path("/", "vlm")
            if backend == "transformers":
                try:
                    from transformers import (
                        AutoProcessor,
                        Qwen2VLForConditionalGeneration,
                    )
                    from transformers import __version__ as transformers_version
                except ImportError:
                    raise ImportError("Please install transformers to use the transformers backend.")
                # transformers 4.56.0 renamed the `torch_dtype` argument to `dtype`.
                if version.parse(transformers_version) >= version.parse("4.56.0"):
                    dtype_key = "dtype"
                else:
                    dtype_key = "torch_dtype"
                device = get_device()
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_path,
                    device_map={"": device},
                    **{dtype_key: "auto"},  # type: ignore
                )
                processor = AutoProcessor.from_pretrained(
                    model_path,
                    use_fast=True,
                )
                if batch_size == 0:
                    batch_size = set_default_batch_size()
            elif backend == "mlx-engine":
                mlx_supported = is_mac_os_version_supported()
                if not mlx_supported:
                    raise EnvironmentError("The mlx-engine backend is only supported on macOS 13.5+ with Apple Silicon.")
                try:
                    from mlx_vlm import load as mlx_load
                except ImportError:
                    raise ImportError("Please install mlx-vlm to use the mlx-engine backend.")
                model, processor = mlx_load(model_path)
            else:
                if os.getenv("OMP_NUM_THREADS") is None:
                    os.environ["OMP_NUM_THREADS"] = "1"
                if backend == "vllm-engine":
                    try:
                        import vllm
                    except ImportError:
                        raise ImportError("Please install vllm to use the vllm-engine backend.")
                    """
                    # Special-case configuration for the musa vllm v1 engine
                    device = get_device()
                    if device.startswith("musa"):
                        import torch
                        if torch.musa.is_available():
                            compilation_config = {
                                "cudagraph_capture_sizes": [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
                                "simple_cuda_graph": True
                            }
                            block_size = 32
                            kwargs["compilation_config"] = compilation_config
                            kwargs["block_size"] = block_size
                    """
                    if "compilation_config" in kwargs:
                        if isinstance(kwargs["compilation_config"], str):
                            try:
                                kwargs["compilation_config"] = json.loads(kwargs["compilation_config"])
                            except json.JSONDecodeError:
                                logger.warning(
                                    f"Failed to parse compilation_config as JSON: {kwargs['compilation_config']}")
                                del kwargs["compilation_config"]
                    if "gpu_memory_utilization" not in kwargs:
                        kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
                    if "model" not in kwargs:
                        kwargs["model"] = model_path
                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                        from mineru_vl_utils import MinerULogitsProcessor
                        kwargs["logits_processors"] = [MinerULogitsProcessor]
                    # Initialize vllm with the remaining kwargs.
                    vllm_llm = vllm.LLM(**kwargs)
                elif backend == "vllm-async-engine":
                    try:
                        from vllm.engine.arg_utils import AsyncEngineArgs
                        from vllm.v1.engine.async_llm import AsyncLLM
                        from vllm.config import CompilationConfig
                    except ImportError:
                        raise ImportError("Please install vllm to use the vllm-async-engine backend.")
                    """
                    # Special-case configuration for the musa vllm v1 engine
                    device = get_device()
                    if device.startswith("musa"):
                        import torch
                        if torch.musa.is_available():
                            compilation_config = CompilationConfig(
                                cudagraph_capture_sizes=[1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 18, 20, 24, 28, 30],
                                simple_cuda_graph=True
                            )
                            block_size = 32
                            kwargs["compilation_config"] = compilation_config
                            kwargs["block_size"] = block_size
                    """
                    if "compilation_config" in kwargs:
                        if isinstance(kwargs["compilation_config"], dict):
                            # If it is a dict, convert it to a CompilationConfig object.
                            kwargs["compilation_config"] = CompilationConfig(**kwargs["compilation_config"])
                        elif isinstance(kwargs["compilation_config"], str):
                            # If it is a JSON string, parse it first, then convert.
                            try:
                                config_dict = json.loads(kwargs["compilation_config"])
                                kwargs["compilation_config"] = CompilationConfig(**config_dict)
                            except (json.JSONDecodeError, TypeError) as e:
                                logger.warning(
                                    f"Failed to parse compilation_config: {kwargs['compilation_config']}, error: {e}")
                                del kwargs["compilation_config"]
                    if "gpu_memory_utilization" not in kwargs:
                        kwargs["gpu_memory_utilization"] = set_default_gpu_memory_utilization()
                    if "model" not in kwargs:
                        kwargs["model"] = model_path
                    if enable_custom_logits_processors() and ("logits_processors" not in kwargs):
                        from mineru_vl_utils import MinerULogitsProcessor
                        kwargs["logits_processors"] = [MinerULogitsProcessor]
                    # Initialize vllm with the remaining kwargs.
                    vllm_async_llm = AsyncLLM.from_engine_args(AsyncEngineArgs(**kwargs))
                elif backend == "lmdeploy-engine":
                    try:
                        from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig
                        from lmdeploy.serve.vl_async_engine import VLAsyncEngine
                    except ImportError:
                        raise ImportError("Please install lmdeploy to use the lmdeploy-engine backend.")
                    if "cache_max_entry_count" not in kwargs:
                        kwargs["cache_max_entry_count"] = 0.5
                    # Resolve the device type: environment variable first, then kwargs, defaulting to "cuda".
                    device_type = os.getenv("MINERU_LMDEPLOY_DEVICE", "")
                    if device_type == "":
                        if "lmdeploy_device" in kwargs:
                            device_type = kwargs.pop("lmdeploy_device")
                            if device_type not in ["cuda", "ascend", "maca", "camb"]:
                                raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
                        else:
                            device_type = "cuda"
                    # Resolve the engine backend: environment variable first, then kwargs, else derive it from the device type.
                    lm_backend = os.getenv("MINERU_LMDEPLOY_BACKEND", "")
                    if lm_backend == "":
                        if "lmdeploy_backend" in kwargs:
                            lm_backend = kwargs.pop("lmdeploy_backend")
                            if lm_backend not in ["pytorch", "turbomind"]:
                                raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
                        else:
                            lm_backend = set_lmdeploy_backend(device_type)
                    logger.info(f"lmdeploy device is: {device_type}, lmdeploy backend is: {lm_backend}")
                    if lm_backend == "pytorch":
                        kwargs["device_type"] = device_type
                        backend_config = PytorchEngineConfig(**kwargs)
                    elif lm_backend == "turbomind":
                        backend_config = TurbomindEngineConfig(**kwargs)
                    else:
                        raise ValueError(f"Unsupported lmdeploy backend: {lm_backend}")
                    # Quiet lmdeploy's own loggers unless the user overrides the level.
                    log_level = "ERROR"
                    from lmdeploy.utils import get_logger
                    lm_logger = get_logger("lmdeploy")
                    lm_logger.setLevel(log_level)
                    if os.getenv("TM_LOG_LEVEL") is None:
                        os.environ["TM_LOG_LEVEL"] = log_level
                    lmdeploy_engine = VLAsyncEngine(
                        model_path,
                        backend=lm_backend,
                        backend_config=backend_config,
                    )
            self._models[key] = MinerUClient(
                backend=backend,
                model=model,
                processor=processor,
                lmdeploy_engine=lmdeploy_engine,
                vllm_llm=vllm_llm,
                vllm_async_llm=vllm_async_llm,
                server_url=server_url,
                batch_size=batch_size,
                max_concurrency=max_concurrency,
                http_timeout=http_timeout,
                server_headers=server_headers,
                max_retries=max_retries,
                retry_backoff_factor=retry_backoff_factor,
            )
            elapsed = round(time.time() - start_time, 2)
            logger.info(f"get {backend} predictor cost: {elapsed}s")
        return self._models[key]
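
# Sketch of how engine kwargs flow through get_model (values here are illustrative,
# not defaults): unrecognized kwargs go straight to the engine constructor, and
# compilation_config may arrive as a JSON string (e.g. from a CLI flag), which the
# vllm branches above parse before use:
#
#     client = ModelSingleton().get_model(
#         backend="vllm-engine",
#         model_path=None,  # auto-downloaded when omitted
#         server_url=None,
#         gpu_memory_utilization=0.7,
#         compilation_config='{"cudagraph_capture_sizes": [1, 2, 4, 8]}',
#     )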


def doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
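    """Run the full VLM pipeline on a PDF synchronously.

    Renders each page to a PIL image, runs the predictor's two-step
    extraction, and converts the model output into middle-json.
    Returns a (middle_json, results) tuple.
    """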
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    load_images_time = round(time.time() - load_images_start, 2)
    # Guard the divisor so a duration that rounds to 0.00s cannot raise ZeroDivisionError.
    logger.debug(f"load images cost: {load_images_time}s, speed: {round(len(images_pil_list) / max(load_images_time, 0.01), 3)} images/s")

    infer_start = time.time()
    results = predictor.batch_two_step_extract(images=images_pil_list)
    infer_time = round(time.time() - infer_start, 2)
    logger.debug(f"infer finished, cost: {infer_time}s, speed: {round(len(results) / max(infer_time, 0.01), 3)} pages/s")

    # result_to_middle_json is a coroutine, so drive it to completion from this synchronous entry point.
    middle_json = asyncio.run(result_to_middle_json(results, images_list, pdf_doc, image_writer))
    return middle_json, results


async def aio_doc_analyze(
    pdf_bytes,
    image_writer: DataWriter | None,
    predictor: MinerUClient | None = None,
    backend="transformers",
    model_path: str | None = None,
    server_url: str | None = None,
    **kwargs,
):
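    """Async counterpart of doc_analyze.

    Same pipeline, but uses the predictor's aio_batch_two_step_extract and
    awaits result_to_middle_json directly instead of wrapping it in
    asyncio.run(). Returns a (middle_json, results) tuple.
    """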
    if predictor is None:
        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)

    load_images_start = time.time()
    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
    load_images_time = round(time.time() - load_images_start, 2)
    logger.debug(f"load images cost: {load_images_time}s, speed: {round(len(images_pil_list) / max(load_images_time, 0.01), 3)} images/s")

    infer_start = time.time()
    results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
    infer_time = round(time.time() - infer_start, 2)
    logger.debug(f"infer finished, cost: {infer_time}s, speed: {round(len(results) / max(infer_time, 0.01), 3)} pages/s")

    middle_json = await result_to_middle_json(results, images_list, pdf_doc, image_writer)
    return middle_json, results
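

# Minimal usage sketch (hypothetical paths; assumes FileBasedDataWriter is the
# concrete DataWriter shipped in mineru.data.data_reader_writer):
#
#     from mineru.data.data_reader_writer import FileBasedDataWriter
#
#     with open("sample.pdf", "rb") as f:
#         pdf_bytes = f.read()
#     image_writer = FileBasedDataWriter("output/images")
#     middle_json, results = doc_analyze(pdf_bytes, image_writer, backend="transformers")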