yolo_v8.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. import os
  2. from typing import List, Union
  3. import torch
  4. from tqdm import tqdm
  5. from ultralytics import YOLO
  6. import numpy as np
  7. from PIL import Image, ImageDraw
  8. from mineru.utils.enum_class import ModelPath
  9. from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
  10. class YOLOv8MFDModel:
  11. def __init__(
  12. self,
  13. weight: str,
  14. device: str = "cpu",
  15. imgsz: int = 1888,
  16. conf: float = 0.25,
  17. iou: float = 0.45,
  18. ):
  19. self.device = torch.device(device)
  20. self.model = YOLO(weight).to(self.device)
  21. self.imgsz = imgsz
  22. self.conf = conf
  23. self.iou = iou
  24. def _run_predict(
  25. self,
  26. inputs: Union[np.ndarray, Image.Image, List],
  27. is_batch: bool = False,
  28. conf: float = None,
  29. ) -> List:
  30. preds = self.model.predict(
  31. inputs,
  32. imgsz=self.imgsz,
  33. conf=conf if conf is not None else self.conf,
  34. iou=self.iou,
  35. verbose=False,
  36. device=self.device
  37. )
  38. return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
  39. def predict(
  40. self,
  41. image: Union[np.ndarray, Image.Image],
  42. conf: float = None,
  43. ):
  44. return self._run_predict(image, is_batch=False, conf=conf)
  45. def batch_predict(
  46. self,
  47. images: List[Union[np.ndarray, Image.Image]],
  48. batch_size: int = 4,
  49. conf: float = None,
  50. ) -> List:
  51. results = []
  52. with tqdm(total=len(images), desc="MFD Predict") as pbar:
  53. for idx in range(0, len(images), batch_size):
  54. batch = images[idx: idx + batch_size]
  55. batch_preds = self._run_predict(batch, is_batch=True, conf=conf)
  56. results.extend(batch_preds)
  57. pbar.update(len(batch))
  58. return results
  59. def visualize(
  60. self,
  61. image: Union[np.ndarray, Image.Image],
  62. results: List
  63. ) -> Image.Image:
  64. if isinstance(image, np.ndarray):
  65. image = Image.fromarray(image)
  66. formula_list = []
  67. for xyxy, conf, cla in zip(
  68. results.boxes.xyxy.cpu(), results.boxes.conf.cpu(), results.boxes.cls.cpu()
  69. ):
  70. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  71. new_item = {
  72. "category_id": 13 + int(cla.item()),
  73. "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  74. "score": round(float(conf.item()), 2),
  75. }
  76. formula_list.append(new_item)
  77. draw = ImageDraw.Draw(image)
  78. for res in formula_list:
  79. poly = res['poly']
  80. xmin, ymin, xmax, ymax = poly[0], poly[1], poly[4], poly[5]
  81. print(
  82. f"Detected box: {xmin}, {ymin}, {xmax}, {ymax}, Category ID: {res['category_id']}, Score: {res['score']}")
  83. # 使用PIL在图像上画框
  84. draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
  85. # 在框旁边画置信度
  86. draw.text((xmax + 10, ymin + 10), f"{res['score']:.2f}", fill="red", font_size=22)
  87. return image
  88. if __name__ == '__main__':
  89. image_path = r"C:\Users\zhaoxiaomeng\Downloads\screenshot-20250821-192948.png"
  90. yolo_v8_mfd_weights = os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd),
  91. ModelPath.yolo_v8_mfd)
  92. device = 'cuda'
  93. model = YOLOv8MFDModel(
  94. weight=yolo_v8_mfd_weights,
  95. device=device,
  96. )
  97. image = Image.open(image_path)
  98. results = model.predict(image)
  99. image = model.visualize(image, results)
  100. image.show() # 显示图像