Unimernet.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import torch
  2. from torch.utils.data import DataLoader, Dataset
  3. from tqdm import tqdm
  4. from mineru.utils.boxbase import calculate_iou
  5. class MathDataset(Dataset):
  6. def __init__(self, image_paths, transform=None):
  7. self.image_paths = image_paths
  8. self.transform = transform
  9. def __len__(self):
  10. return len(self.image_paths)
  11. def __getitem__(self, idx):
  12. raw_image = self.image_paths[idx]
  13. if self.transform:
  14. image = self.transform(raw_image)
  15. return image
  16. class UnimernetModel(object):
  17. def __init__(self, weight_dir, _device_="cpu"):
  18. from .unimernet_hf import UnimernetModel
  19. if _device_.startswith("mps") or _device_.startswith("npu") or _device_.startswith("musa"):
  20. self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
  21. else:
  22. self.model = UnimernetModel.from_pretrained(weight_dir)
  23. self.device = torch.device(_device_)
  24. self.model.to(self.device)
  25. if not _device_.startswith("cpu"):
  26. self.model = self.model.to(dtype=torch.float16)
  27. self.model.eval()
  28. @staticmethod
  29. def _filter_boxes_by_iou(xyxy, conf, cla, iou_threshold=0.8):
  30. """过滤IOU超过阈值的重叠框,保留置信度较高的框。
  31. Args:
  32. xyxy: 框坐标张量,shape为(N, 4)
  33. conf: 置信度张量,shape为(N,)
  34. cla: 类别张量,shape为(N,)
  35. iou_threshold: IOU阈值,默认0.9
  36. Returns:
  37. 过滤后的xyxy, conf, cla张量
  38. """
  39. if len(xyxy) == 0:
  40. return xyxy, conf, cla
  41. # 转换为CPU进行处理
  42. xyxy_cpu = xyxy.cpu()
  43. conf_cpu = conf.cpu()
  44. n = len(xyxy_cpu)
  45. keep = [True] * n
  46. for i in range(n):
  47. if not keep[i]:
  48. continue
  49. bbox1 = xyxy_cpu[i].tolist()
  50. for j in range(i + 1, n):
  51. if not keep[j]:
  52. continue
  53. bbox2 = xyxy_cpu[j].tolist()
  54. iou = calculate_iou(bbox1, bbox2)
  55. if iou > iou_threshold:
  56. # 保留置信度较高的框
  57. if conf_cpu[i] >= conf_cpu[j]:
  58. keep[j] = False
  59. else:
  60. keep[i] = False
  61. break # i被删除,跳出内循环
  62. keep_indices = [i for i in range(n) if keep[i]]
  63. if len(keep_indices) == n:
  64. return xyxy, conf, cla
  65. keep_indices = torch.tensor(keep_indices, dtype=torch.long)
  66. return xyxy[keep_indices], conf[keep_indices], cla[keep_indices]
  67. def predict(self, mfd_res, image):
  68. formula_list = []
  69. mf_image_list = []
  70. # 对检测框进行IOU去重,保留置信度较高的框
  71. xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
  72. mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
  73. )
  74. for xyxy, conf, cla in zip(
  75. xyxy_filtered.cpu(), conf_filtered.cpu(), cla_filtered.cpu()
  76. ):
  77. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  78. new_item = {
  79. "category_id": 13 + int(cla.item()),
  80. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  81. "score": round(float(conf.item()), 2),
  82. "latex": "",
  83. }
  84. formula_list.append(new_item)
  85. bbox_img = image[ymin:ymax, xmin:xmax]
  86. mf_image_list.append(bbox_img)
  87. dataset = MathDataset(mf_image_list, transform=self.model.transform)
  88. dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
  89. mfr_res = []
  90. for mf_img in dataloader:
  91. mf_img = mf_img.to(dtype=self.model.dtype)
  92. mf_img = mf_img.to(self.device)
  93. with torch.no_grad():
  94. output = self.model.generate({"image": mf_img})
  95. mfr_res.extend(output["fixed_str"])
  96. for res, latex in zip(formula_list, mfr_res):
  97. res["latex"] = latex
  98. return formula_list
  99. def batch_predict(
  100. self,
  101. images_mfd_res: list,
  102. images: list,
  103. batch_size: int = 64,
  104. interline_enable: bool = True,
  105. ) -> list:
  106. images_formula_list = []
  107. mf_image_list = []
  108. backfill_list = []
  109. image_info = [] # Store (area, original_index, image) tuples
  110. # Collect images with their original indices
  111. for image_index in range(len(images_mfd_res)):
  112. mfd_res = images_mfd_res[image_index]
  113. image = images[image_index]
  114. formula_list = []
  115. # 对检测框进行IOU去重,保留置信度较高的框
  116. xyxy_filtered, conf_filtered, cla_filtered = self._filter_boxes_by_iou(
  117. mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
  118. )
  119. for idx, (xyxy, conf, cla) in enumerate(zip(
  120. xyxy_filtered, conf_filtered, cla_filtered
  121. )):
  122. if not interline_enable and cla.item() == 1:
  123. continue # Skip interline regions if not enabled
  124. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  125. new_item = {
  126. "category_id": 13 + int(cla.item()),
  127. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  128. "score": round(float(conf.item()), 2),
  129. "latex": "",
  130. }
  131. formula_list.append(new_item)
  132. bbox_img = image[ymin:ymax, xmin:xmax]
  133. area = (xmax - xmin) * (ymax - ymin)
  134. curr_idx = len(mf_image_list)
  135. image_info.append((area, curr_idx, bbox_img))
  136. mf_image_list.append(bbox_img)
  137. images_formula_list.append(formula_list)
  138. backfill_list += formula_list
  139. # Stable sort by area
  140. image_info.sort(key=lambda x: x[0]) # sort by area
  141. sorted_indices = [x[1] for x in image_info]
  142. sorted_images = [x[2] for x in image_info]
  143. # Create mapping for results
  144. index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
  145. # Create dataset with sorted images
  146. dataset = MathDataset(sorted_images, transform=self.model.transform)
  147. # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂
  148. batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
  149. dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
  150. # Process batches and store results
  151. mfr_res = []
  152. # for mf_img in dataloader:
  153. with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
  154. for index, mf_img in enumerate(dataloader):
  155. mf_img = mf_img.to(dtype=self.model.dtype)
  156. mf_img = mf_img.to(self.device)
  157. with torch.no_grad():
  158. output = self.model.generate({"image": mf_img}, batch_size=batch_size)
  159. mfr_res.extend(output["fixed_str"])
  160. # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
  161. current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
  162. pbar.update(current_batch_size)
  163. # Restore original order
  164. unsorted_results = [""] * len(mfr_res)
  165. for new_idx, latex in enumerate(mfr_res):
  166. original_idx = index_mapping[new_idx]
  167. unsorted_results[original_idx] = latex
  168. # Fill results back
  169. for res, latex in zip(backfill_list, unsorted_results):
  170. res["latex"] = latex
  171. return images_formula_list