# modeling_unimernet.py
  1. import os
  2. import warnings
  3. from typing import Optional
  4. import torch
  5. from ftfy import fix_text
  6. from loguru import logger
  7. from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
  8. from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
  9. from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger
  10. from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
  11. from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM
  12. from ...utils import latex_rm_whitespace
# Register the UniMER-specific configs and model classes with the transformers
# Auto* factories, so AutoConfig/AutoModel/AutoModelForCausalLM lookups resolve
# the custom `model_type` strings to the local Swin encoder / MBart decoder.
AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)
  17. # TODO: rewrite tokenizer
  18. class TokenizerWrapper:
  19. def __init__(self, tokenizer):
  20. self.tokenizer = tokenizer
  21. self.pad_token_id = self.tokenizer.pad_token_id
  22. self.bos_token_id = self.tokenizer.bos_token_id
  23. self.eos_token_id = self.tokenizer.eos_token_id
  24. def __len__(self):
  25. return len(self.tokenizer)
  26. def tokenize(self, text, **kwargs):
  27. return self.tokenizer(
  28. text,
  29. return_token_type_ids=False,
  30. return_tensors="pt",
  31. padding="longest",
  32. truncation=True,
  33. **kwargs,
  34. )
  35. def token2str(self, tokens) -> list:
  36. generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
  37. generated_text = [fix_text(text) for text in generated_text]
  38. return generated_text
  39. def detokenize(self, tokens):
  40. toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
  41. for b in range(len(toks)):
  42. for i in reversed(range(len(toks[b]))):
  43. if toks[b][i] is None:
  44. toks[b][i] = ''
  45. toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
  46. if toks[b][i] in ([self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]):
  47. del toks[b][i]
  48. return toks
  49. class UnimernetModel(VisionEncoderDecoderModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        """Build the encoder/decoder pair, then attach the image processor
        and the tokenizer loaded from ``config._name_or_path``.

        Raises:
            RuntimeError: if *config* is missing or lacks ``_name_or_path``.
        """
        # VisionEncoderDecoderModel's checking log has bug, disable for temp.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            # Re-enable the logger even when construction raises.
            base_model_logger.disabled = False

        # The tokenizer is resolved from the checkpoint path, so a config
        # carrying _name_or_path is mandatory here.
        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")
        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()
  68. def _post_check(self):
  69. tokenizer = self.tokenizer
  70. if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
  71. warnings.warn(
  72. f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings}," +
  73. f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set" +
  74. f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
  75. tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings
  76. assert self.config.decoder.vocab_size == len(tokenizer)
  77. assert self.config.decoder_start_token_id == tokenizer.bos_token_id
  78. assert self.config.pad_token_id == tokenizer.pad_token_id
  79. @classmethod
  80. def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth", state_dict_strip_prefix="model.model."):
  81. config = VisionEncoderDecoderConfig.from_pretrained(model_path)
  82. config._name_or_path = model_path
  83. config.encoder = UnimerSwinConfig(**vars(config.encoder))
  84. config.decoder = UnimerMBartConfig(**vars(config.decoder))
  85. encoder = UnimerSwinModel(config.encoder)
  86. decoder = UnimerMBartForCausalLM(config.decoder)
  87. model = cls(config, encoder, decoder)
  88. # load model weights
  89. model_file_path = os.path.join(model_path, model_filename)
  90. checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
  91. state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
  92. if not state_dict:
  93. raise RuntimeError("state_dict is empty.")
  94. if state_dict_strip_prefix:
  95. state_dict = {
  96. k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
  97. for k, v in state_dict.items()
  98. }
  99. missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
  100. if len(unexpected_keys) > 0:
  101. warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
  102. if len(missing_keys) > 0:
  103. raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
  104. return model
    def forward_bak(self, samples):
        """Training-style forward pass (kept as a backup, judging by the name).

        Tokenizes the target text, builds teacher-forcing decoder inputs and
        padding-masked labels, and returns the decoder loss.

        Args:
            samples: dict with "image" (pixel tensor, channels at dim 1) and
                "text_input" (target strings) — assumed shapes; TODO confirm.

        Returns:
            dict with a single "loss" entry.
        """
        pixel_values, text = samples["image"], samples["text_input"]
        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        # Grayscale images are tiled to 3 channels for the vision encoder.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        # Copy the ids (`* 1`) and mask padding positions out of the loss
        # (-100 is ignored by the cross-entropy loss).
        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        # NOTE(review): `self.model` is not assigned anywhere in this class —
        # presumably a leftover from an earlier wrapper; confirm before using
        # this method.
        loss = self.model(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}
  121. def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, batch_size=64):
  122. pixel_values = samples["image"]
  123. num_channels = pixel_values.shape[1]
  124. if num_channels == 1:
  125. pixel_values = pixel_values.repeat(1, 3, 1, 1)
  126. kwargs = {}
  127. if do_sample:
  128. kwargs["temperature"] = temperature
  129. kwargs["top_p"] = top_p
  130. if self.tokenizer.tokenizer.model_max_length > 1152:
  131. if batch_size <= 32:
  132. self.tokenizer.tokenizer.model_max_length = 1152 # 6g
  133. else:
  134. self.tokenizer.tokenizer.model_max_length = 1344 # 8g
  135. outputs = super().generate(
  136. pixel_values=pixel_values,
  137. max_new_tokens=self.tokenizer.tokenizer.model_max_length, # required
  138. decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
  139. do_sample=do_sample,
  140. **kwargs,
  141. )
  142. outputs = outputs[:, 1:].cpu().numpy()
  143. pred_tokens = self.tokenizer.detokenize(outputs)
  144. pred_str = self.tokenizer.token2str(outputs)
  145. fixed_str = [latex_rm_whitespace(s) for s in pred_str]
  146. return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}