vllm_server.py
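
"""Launch a vllm server with MinerU-specific defaults.

Scans the command-line arguments, fills in defaults (port, GPU memory
utilization, model path, logits processors) for anything the caller did not
supply, then hands control to vllm's "serve" CLI entry point.
"""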

import os
import sys

from mineru.backend.vlm.utils import set_default_gpu_memory_utilization, enable_custom_logits_processors
from mineru.utils.config_reader import get_device
from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
from vllm.entrypoints.cli.main import main as vllm_main


def main():
    args = sys.argv[1:]

    has_port_arg = False
    has_gpu_memory_utilization_arg = False
    has_logits_processors_arg = False
    has_block_size_arg = False
    has_compilation_config = False
    model_path = None
    model_arg_indices = []
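
    # Each flag is matched in both "--flag value" and "--flag=value" forms,
    # so user-supplied values are never overridden by the defaults added below.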
    # Check existing arguments
    for i, arg in enumerate(args):
        if arg == "--port" or arg.startswith("--port="):
            has_port_arg = True
        if arg == "--gpu-memory-utilization" or arg.startswith("--gpu-memory-utilization="):
            has_gpu_memory_utilization_arg = True
        if arg == "--logits-processors" or arg.startswith("--logits-processors="):
            has_logits_processors_arg = True
        if arg == "--block-size" or arg.startswith("--block-size="):
            has_block_size_arg = True
        if arg == "--compilation-config" or arg.startswith("--compilation-config="):
            has_compilation_config = True
        if arg == "--model":
            if i + 1 < len(args):
                model_path = args[i + 1]
                model_arg_indices.extend([i, i + 1])
        elif arg.startswith("--model="):
            model_path = arg.split("=", 1)[1]
            model_arg_indices.append(i)

    # Remove the --model arguments from the list; pop from the highest
    # index first so the remaining indices stay valid
    if model_arg_indices:
        for index in sorted(model_arg_indices, reverse=True):
            args.pop(index)

    custom_logits_processors = enable_custom_logits_processors()

    # Add default arguments
    if not has_port_arg:
        args.extend(["--port", "30000"])
    if not has_gpu_memory_utilization_arg:
        gpu_memory_utilization = str(set_default_gpu_memory_utilization())
        args.extend(["--gpu-memory-utilization", gpu_memory_utilization])
    if not model_path:
        model_path = auto_download_and_get_model_root_path("/", "vlm")
    if (not has_logits_processors_arg) and custom_logits_processors:
        args.extend(["--logits-processors", "mineru_vl_utils:MinerULogitsProcessor"])
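
    # NOTE: the musa-specific tuning below is wrapped in a string literal,
    # so it is effectively disabled and never executes.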
  50. """
  51. # musa vllm v1 引擎特殊配置
  52. device = get_device()
  53. if device.startswith("musa"):
  54. import torch
  55. if torch.musa.is_available():
  56. if not has_block_size_arg:
  57. args.extend(["--block-size", "32"])
  58. if not has_compilation_config:
  59. args.extend(["--compilation-config", '{"cudagraph_capture_sizes": [1,2,3,4,5,6,7,8,10,12,14,16,18,20,24,28,30], "simple_cuda_graph": true}'])
  60. """

    # Rebuild argv, passing the model path as a positional argument to "serve"
    sys.argv = [sys.argv[0]] + ["serve", model_path] + args
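
    # Default OMP_NUM_THREADS to 1 unless the caller has set it; a common
    # guard against CPU thread oversubscription in multi-worker servers.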
    if os.getenv('OMP_NUM_THREADS') is None:
        os.environ["OMP_NUM_THREADS"] = "1"

    # Start the vllm server
    print(f"start vllm server: {sys.argv}")
    vllm_main()


if __name__ == "__main__":
    main()
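
# A minimal usage sketch (assuming mineru and vllm are installed); any flags
# not handled above are forwarded to vllm unchanged, and /path/to/model is a
# placeholder:
#
#   python vllm_server.py                       # defaults, serves on port 30000
#   python vllm_server.py --port 8000 --model /path/to/model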