# utils.py
  1. import os
  2. from loguru import logger
  3. from packaging import version
  4. from mineru.utils.check_sys_env import is_windows_environment, is_linux_environment
  5. from mineru.utils.config_reader import get_device
  6. from mineru.utils.model_utils import get_vram
  7. def enable_custom_logits_processors() -> bool:
  8. import torch
  9. from vllm import __version__ as vllm_version
  10. if torch.cuda.is_available():
  11. major, minor = torch.cuda.get_device_capability()
  12. # 正确计算Compute Capability
  13. compute_capability = f"{major}.{minor}"
  14. elif hasattr(torch, 'npu') and torch.npu.is_available():
  15. compute_capability = "8.0"
  16. elif hasattr(torch, 'gcu') and torch.gcu.is_available():
  17. compute_capability = "8.0"
  18. elif hasattr(torch, 'musa') and torch.musa.is_available():
  19. compute_capability = "8.0"
  20. else:
  21. logger.info("CUDA not available, disabling custom_logits_processors")
  22. return False
  23. # 安全地处理环境变量
  24. vllm_use_v1_str = os.getenv('VLLM_USE_V1', "1")
  25. if vllm_use_v1_str.isdigit():
  26. vllm_use_v1 = int(vllm_use_v1_str)
  27. else:
  28. vllm_use_v1 = 1
  29. if vllm_use_v1 == 0:
  30. logger.info("VLLM_USE_V1 is set to 0, disabling custom_logits_processors")
  31. return False
  32. elif version.parse(vllm_version) < version.parse("0.10.1"):
  33. logger.info(f"vllm version: {vllm_version} < 0.10.1, disable custom_logits_processors")
  34. return False
  35. elif version.parse(compute_capability) < version.parse("8.0"):
  36. if version.parse(vllm_version) >= version.parse("0.10.2"):
  37. logger.info(f"compute_capability: {compute_capability} < 8.0, but vllm version: {vllm_version} >= 0.10.2, enable custom_logits_processors")
  38. return True
  39. else:
  40. logger.info(f"compute_capability: {compute_capability} < 8.0 and vllm version: {vllm_version} < 0.10.2, disable custom_logits_processors")
  41. return False
  42. else:
  43. logger.info(f"compute_capability: {compute_capability} >= 8.0 and vllm version: {vllm_version} >= 0.10.1, enable custom_logits_processors")
  44. return True
  45. def set_lmdeploy_backend(device_type: str) -> str:
  46. if device_type.lower() in ["ascend", "maca", "camb"]:
  47. lmdeploy_backend = "pytorch"
  48. elif device_type.lower() in ["cuda"]:
  49. import torch
  50. if not torch.cuda.is_available():
  51. raise ValueError("CUDA is not available.")
  52. if is_windows_environment():
  53. lmdeploy_backend = "turbomind"
  54. elif is_linux_environment():
  55. major, minor = torch.cuda.get_device_capability()
  56. compute_capability = f"{major}.{minor}"
  57. if version.parse(compute_capability) >= version.parse("8.0"):
  58. lmdeploy_backend = "pytorch"
  59. else:
  60. lmdeploy_backend = "turbomind"
  61. else:
  62. raise ValueError("Unsupported operating system.")
  63. else:
  64. raise ValueError(f"Unsupported lmdeploy device type: {device_type}")
  65. return lmdeploy_backend
  66. def set_default_gpu_memory_utilization() -> float:
  67. from vllm import __version__ as vllm_version
  68. device = get_device()
  69. gpu_memory = get_vram(device)
  70. if version.parse(vllm_version) >= version.parse("0.11.0") and gpu_memory <= 8:
  71. return 0.7
  72. else:
  73. return 0.5
  74. def set_default_batch_size() -> int:
  75. try:
  76. device = get_device()
  77. gpu_memory = get_vram(device)
  78. if gpu_memory >= 16:
  79. batch_size = 8
  80. elif gpu_memory >= 8:
  81. batch_size = 4
  82. else:
  83. batch_size = 1
  84. logger.info(f'gpu_memory: {gpu_memory} GB, batch_size: {batch_size}')
  85. except Exception as e:
  86. logger.warning(f'Error determining VRAM: {e}, using default batch_ratio: 1')
  87. batch_size = 1
  88. return batch_size