123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- # Ultralytics YOLO 🚀, AGPL-3.0 license
- from pathlib import Path
- from ultralytics.engine.model import Model
- from ultralytics.models import yolo
- from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel
- from ultralytics.utils import ROOT, yaml_load
class YOLO(Model):
    """
    YOLO (You Only Look Once) object detection model.

    Constructing ``YOLO`` from a '-world' model file (``.pt``, ``.yaml`` or ``.yml``) turns the
    new object into a ``YOLOWorld`` instance instead.
    """

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Build a YOLO model, delegating to YOLOWorld when the model filename contains '-world'.

        Args:
            model (str | Path): Model file to load. Defaults to 'yolov8n.pt'.
            task (str | None): Task name forwarded to ``Model.__init__``; not used for '-world' models.
            verbose (bool): Verbosity flag forwarded to the underlying constructor.
        """
        model_path = Path(model)
        wants_world = "-world" in model_path.stem and model_path.suffix in {".pt", ".yaml", ".yml"}
        if not wants_world:
            # Regular YOLO model: defer entirely to the base Model initializer.
            super().__init__(model=model, task=task, verbose=verbose)
            return

        # YOLOWorld PyTorch model: build it, then morph this object into it by adopting
        # both its class and its instance attribute dictionary.
        world_model = YOLOWorld(model_path, verbose=verbose)
        self.__class__ = type(world_model)
        self.__dict__ = world_model.__dict__

    @property
    def task_map(self):
        """
        Return, per task name, the model, trainer, validator and predictor classes to use.

        A fresh nested dict is built on every access.
        """
        # Order of the per-task component tuples below must match this role order.
        roles = ("model", "trainer", "validator", "predictor")
        components = {
            "classify": (
                ClassificationModel,
                yolo.classify.ClassificationTrainer,
                yolo.classify.ClassificationValidator,
                yolo.classify.ClassificationPredictor,
            ),
            "detect": (
                DetectionModel,
                yolo.detect.DetectionTrainer,
                yolo.detect.DetectionValidator,
                yolo.detect.DetectionPredictor,
            ),
            "segment": (
                SegmentationModel,
                yolo.segment.SegmentationTrainer,
                yolo.segment.SegmentationValidator,
                yolo.segment.SegmentationPredictor,
            ),
            "pose": (
                PoseModel,
                yolo.pose.PoseTrainer,
                yolo.pose.PoseValidator,
                yolo.pose.PosePredictor,
            ),
            "obb": (
                OBBModel,
                yolo.obb.OBBTrainer,
                yolo.obb.OBBValidator,
                yolo.obb.OBBPredictor,
            ),
        }
        return {task: dict(zip(roles, classes)) for task, classes in components.items()}
class YOLOWorld(Model):
    """YOLO-World object detection model whose class vocabulary can be changed via ``set_classes``."""

    def __init__(self, model="yolov8s-world.pt", verbose=False) -> None:
        """
        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
            verbose (bool): Verbosity flag forwarded to ``Model.__init__``. Defaults to False.
        """
        # YOLO-World is always constructed as a detection model.
        super().__init__(model=model, task="detect", verbose=verbose)

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """Map head to model, validator, predictor, and trainer classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }

    def set_classes(self, classes):
        """
        Set the class vocabulary used by the model.

        The full sequence (including a " " background entry, if present) is passed to
        ``self.model.set_classes``. The first " " entry is then dropped from the names stored
        on ``self.model.names`` and, if a predictor already exists, on ``self.predictor.model.names``.

        Args:
            classes (list[str] | tuple[str, ...]): Category names, e.g. ["person"]. The caller's
                sequence is not modified.
        """
        self.model.set_classes(classes)

        # Remove background if it's given. Work on a copy so the caller's list is not mutated,
        # is not aliased into model.names, and so tuple inputs (no .remove) also work.
        background = " "
        names = list(classes)
        if background in names:
            names.remove(background)  # removes only the first occurrence, as before
        # NOTE(review): the model was configured with len(classes) entries, so names indices
        # only line up with model outputs if the background entry is last — confirm with callers.
        self.model.names = names

        # Update an already-created predictor, otherwise it keeps reporting the old names.
        if self.predictor:
            self.predictor.model.names = names
|