model.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
  4. performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
  5. hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
  6. For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
  7. """
  8. from ultralytics.engine.model import Model
  9. from ultralytics.nn.tasks import RTDETRDetectionModel
  10. from .predict import RTDETRPredictor
  11. from .train import RTDETRTrainer
  12. from .val import RTDETRValidator
  13. class RTDETR(Model):
  14. """
  15. Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
  16. with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
  17. Attributes:
  18. model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
  19. """
  20. def __init__(self, model="rtdetr-l.pt") -> None:
  21. """
  22. Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
  23. Args:
  24. model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
  25. Raises:
  26. NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
  27. """
  28. super().__init__(model=model, task="detect")
  29. @property
  30. def task_map(self) -> dict:
  31. """
  32. Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
  33. Returns:
  34. dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
  35. """
  36. return {
  37. "detect": {
  38. "predictor": RTDETRPredictor,
  39. "validator": RTDETRValidator,
  40. "trainer": RTDETRTrainer,
  41. "model": RTDETRDetectionModel,
  42. }
  43. }