model.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from pathlib import Path
  3. from ultralytics.engine.model import Model
  4. from ultralytics.models import yolo
  5. from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
  6. from ultralytics.utils import ROOT, yaml_load
  7. class YOLO(Model):
  8. """YOLO (You Only Look Once) object detection model."""
  9. def __init__(self, model="yolov8n.pt", task=None, verbose=False):
  10. """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'."""
  11. path = Path(model)
  12. if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
  13. new_instance = YOLOWorld(path, verbose=verbose)
  14. self.__class__ = type(new_instance)
  15. self.__dict__ = new_instance.__dict__
  16. else:
  17. # Continue with default YOLO initialization
  18. super().__init__(model=model, task=task, verbose=verbose)
  19. @property
  20. def task_map(self):
  21. """Map head to model, trainer, validator, and predictor classes."""
  22. return {
  23. "classify": {
  24. "model": ClassificationModel,
  25. "trainer": yolo.classify.ClassificationTrainer,
  26. "validator": yolo.classify.ClassificationValidator,
  27. "predictor": yolo.classify.ClassificationPredictor,
  28. },
  29. "detect": {
  30. "model": DetectionModel,
  31. "trainer": yolo.detect.DetectionTrainer,
  32. "validator": yolo.detect.DetectionValidator,
  33. "predictor": yolo.detect.DetectionPredictor,
  34. },
  35. "segment": {
  36. "model": SegmentationModel,
  37. "trainer": yolo.segment.SegmentationTrainer,
  38. "validator": yolo.segment.SegmentationValidator,
  39. "predictor": yolo.segment.SegmentationPredictor,
  40. },
  41. "pose": {
  42. "model": PoseModel,
  43. "trainer": yolo.pose.PoseTrainer,
  44. "validator": yolo.pose.PoseValidator,
  45. "predictor": yolo.pose.PosePredictor,
  46. },
  47. "obb": {
  48. "model": OBBModel,
  49. "trainer": yolo.obb.OBBTrainer,
  50. "validator": yolo.obb.OBBValidator,
  51. "predictor": yolo.obb.OBBPredictor,
  52. },
  53. }
  54. class YOLOWorld(Model):
  55. """YOLO-World object detection model."""
  56. def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
  57. """
  58. Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.
  59. Args:
  60. model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
  61. """
  62. super().__init__(model=model, task="detect", verbose=verbose)
  63. # Assign default COCO class names when there are no custom names
  64. if not hasattr(self.model, "names"):
  65. self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")
  66. @property
  67. def task_map(self):
  68. """Map head to model, validator, and predictor classes."""
  69. return {
  70. "detect": {
  71. "model": WorldModel,
  72. "validator": yolo.detect.DetectionValidator,
  73. "predictor": yolo.detect.DetectionPredictor,
  74. "trainer": yolo.world.WorldTrainer,
  75. }
  76. }
  77. def set_classes(self, classes):
  78. """
  79. Set classes.
  80. Args:
  81. classes (List(str)): A list of categories i.e. ["person"].
  82. """
  83. self.model.set_classes(classes)
  84. # Remove background if it's given
  85. background = " "
  86. if background in classes:
  87. classes.remove(background)
  88. self.model.names = classes
  89. # Reset method class names
  90. # self.predictor = None # reset predictor otherwise old names remain
  91. if self.predictor:
  92. self.predictor.model.names = classes