# load_model.py

import os

import torch
from pymilvus import model
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda:0" if torch.cuda.is_available() else "cpu"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Load the embedding model the sentence-transformers way (kept commented for
# reference; the OpenAI-compatible vLLM endpoints below are used instead).
# embedding_path = r"/opt/vllm/models/BAAI/bge-m3"  # production path
bge_m3_base_url = r"http://10.1.14.16:8787/v1"
bge_me_model = "bge-m3"
qwen_ed_base_url = r"http://10.1.14.16:8788/v1"
qwen_ed_model = "Qwen3-Embedding"
# embedding_path = r"G:/work/code/models/multilingual-e5-large-instruct/"  # local path
# sentence_transformer_ef = model.dense.SentenceTransformerEmbeddingFunction(model_name=embedding_path, device=device)
bge_m3_ef = model.dense.OpenAIEmbeddingFunction(
    model_name=bge_me_model,   # model name served by the endpoint
    api_key='YOUR_API_KEY',    # placeholder key for the local vLLM endpoint
    # dimensions=512,          # optionally set the embedding dimensionality
    base_url=bge_m3_base_url,
)
# Record the output dimensionality for this custom model name.
bge_m3_ef._openai_model_meta_info[bge_me_model]["dim"] = 1024

qwen_ed_ef = model.dense.OpenAIEmbeddingFunction(
    model_name=qwen_ed_model,  # model name served by the endpoint
    api_key='YOUR_API_KEY',    # placeholder key for the local vLLM endpoint
    # dimensions=512,          # optionally set the embedding dimensionality
    base_url=qwen_ed_base_url,
)
qwen_ed_ef._openai_model_meta_info[qwen_ed_model]["dim"] = 1024
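
# Illustrative usage sketch (not called at import time): pymilvus embedding
# functions expose encode_documents()/encode_queries(). The helper name and
# the sample text are assumptions for demonstration only.
def _demo_embed(texts=("what is vector search?",)):
    vectors = bge_m3_ef.encode_documents(list(texts))
    # Each vector's length should match the 1024-dim override set above.
    return vectors
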
# sentence_transformer_ef = model.hybrid.BGEM3EmbeddingFunction(model_name=embedding_path, device=device, use_fp16=False)
# embedding_path_qwen = r"/opt/vllm/models/Qwen/Qwen3-Embedding-0.6B"
# sentence_transformer_qwen = model.dense.SentenceTransformerEmbeddingFunction(
#     model_name=embedding_path_qwen,
#     device=device,
#     model_kwargs={
#         "attn_implementation": "flash_attention_2",  # speed up inference
#         "device_map": "auto",                        # automatic device placement
#         "torch_dtype": torch.float16,
#     },
#     tokenizer_kwargs={
#         "padding_side": "left",  # pad on the left
#     },
# )
# from transformers import AutoTokenizer, AutoModel
# import torch
#
# class QwenEmbedding:
#     def __init__(self, model_path, device="cuda"):
#         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
#         self.model = AutoModel.from_pretrained(
#             model_path,
#             trust_remote_code=True,
#             torch_dtype=torch.float32,
#         ).to(device)
#         self.device = device
#
#     @torch.no_grad()
#     def __call__(self, texts):
#         inputs = self.tokenizer(
#             texts, padding=True, truncation=True, return_tensors="pt"
#         ).to(self.device)
#         outputs = self.model(**inputs)
#         # Typically take the CLS token or a mean-pool over last_hidden_state.
#         emb = outputs.last_hidden_state.mean(dim=1)
#         return emb.cpu().numpy()
#
# sentence_transformer_qwen = QwenEmbedding(embedding_path_qwen, "cuda:1")
# # Rerank model (BGE), loaded locally with transformers
# bce_rerank_model_path = r"/opt/vllm/models/BAAI/bge-reranker-v2-m3"  # production path
# # bce_rerank_model_path = r"G:/work/code/models/bce-reranker-base_v1"  # local path
# bce_rerank_tokenizer = AutoTokenizer.from_pretrained(bce_rerank_model_path)
# bce_rerank_base_model = AutoModelForSequenceClassification.from_pretrained(bce_rerank_model_path).to(device)
# # Rerank model (Qwen), loaded locally with transformers
# qwen_rerank_model_path = r"/opt/vllm/models/Qwen/Qwen3-Reranker-0.6B"  # production path
# # bce_rerank_model_path = r"G:/work/code/models/bce-reranker-base_v1"  # local path
# qwen_rerank_tokenizer = AutoTokenizer.from_pretrained(qwen_rerank_model_path)
# qwen_rerank_base_model = AutoModelForSequenceClassification.from_pretrained(qwen_rerank_model_path).to(device)
# Reranker service endpoints
rerank_bge_url = "http://10.1.14.16:8791"
rerank_bge_model = "bge-reranker-v2-m3"
rerank_qwen_url = "http://10.1.14.16:8790"
rerank_qwen_model = "Qwen3-Reranker-0.6B"
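
# Illustrative reranker call (hypothetical helper; assumes the servers above
# expose a Jina-style /rerank route such as vLLM's OpenAI-compatible server
# provides -- adjust the path and response fields if your deployment differs).
import requests

def _demo_rerank(query, documents, url=rerank_bge_url, model=rerank_bge_model):
    resp = requests.post(
        f"{url}/rerank",
        json={"model": model, "query": query, "documents": documents},
        timeout=30,
    )
    resp.raise_for_status()
    # Expected shape: {"results": [{"index": ..., "relevance_score": ...}, ...]}
    return sorted(resp.json()["results"],
                  key=lambda r: r["relevance_score"], reverse=True)
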
tokenizer = AutoTokenizer.from_pretrained(
    "/opt/vllm/models/Qwen/Qwen3-30B-A3B-Instruct-2507",
    trust_remote_code=True,  # commonly needed when the model ships a custom tokenizer
)
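
# Illustrative helper (an assumption, not from the original file): count tokens
# with the Qwen3-30B tokenizer loaded above, e.g. to budget prompt length.
def _count_tokens(text: str) -> int:
    return len(tokenizer.encode(text, add_special_tokens=False))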