engine_utils.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. from loguru import logger
  3. from mineru.utils.check_sys_env import is_mac_os_version_supported, is_windows_environment, is_mac_environment, \
  4. is_linux_environment
  5. def get_vlm_engine(inference_engine: str, is_async: bool = False) -> str:
  6. """
  7. 自动选择或验证 VLM 推理引擎
  8. Args:
  9. inference_engine: 指定的引擎名称或 'auto' 进行自动选择
  10. is_async: 是否使用异步引擎(仅对 vllm 有效)
  11. Returns:
  12. 最终选择的引擎名称
  13. """
  14. if inference_engine == 'auto':
  15. # 根据操作系统自动选择引擎
  16. if is_windows_environment():
  17. inference_engine = _select_windows_engine()
  18. elif is_linux_environment():
  19. inference_engine = _select_linux_engine(is_async)
  20. elif is_mac_environment():
  21. inference_engine = _select_mac_engine()
  22. else:
  23. logger.warning("Unknown operating system, falling back to transformers")
  24. inference_engine = 'transformers'
  25. formatted_engine = _format_engine_name(inference_engine)
  26. logger.info(f"Using {formatted_engine} as the inference engine for VLM.")
  27. return formatted_engine
  28. def _select_windows_engine() -> str:
  29. """Windows 平台引擎选择"""
  30. try:
  31. import lmdeploy
  32. return 'lmdeploy'
  33. except ImportError:
  34. return 'transformers'
  35. def _select_linux_engine(is_async: bool) -> str:
  36. """Linux 平台引擎选择"""
  37. try:
  38. import vllm
  39. return 'vllm-async' if is_async else 'vllm'
  40. except ImportError:
  41. try:
  42. import lmdeploy
  43. return 'lmdeploy'
  44. except ImportError:
  45. return 'transformers'
  46. def _select_mac_engine() -> str:
  47. """macOS 平台引擎选择"""
  48. try:
  49. from mlx_vlm import load as mlx_load
  50. if is_mac_os_version_supported():
  51. return 'mlx'
  52. else:
  53. return 'transformers'
  54. except ImportError:
  55. return 'transformers'
  56. def _format_engine_name(engine: str) -> str:
  57. """统一格式化引擎名称"""
  58. if engine != 'transformers':
  59. return f"{engine}-engine"
  60. return engine