train.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from copy import copy
  3. import torch
  4. from ultralytics.models.yolo.detect import DetectionTrainer
  5. from ultralytics.nn.tasks import RTDETRDetectionModel
  6. from ultralytics.utils import RANK, colorstr
  7. from .val import RTDETRDataset, RTDETRValidator
  8. class RTDETRTrainer(DetectionTrainer):
  9. """
  10. Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
  11. class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
  12. Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
  13. Notes:
  14. - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
  15. - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
  16. Example:
  17. ```python
  18. from ultralytics.models.rtdetr.train import RTDETRTrainer
  19. args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
  20. trainer = RTDETRTrainer(overrides=args)
  21. trainer.train()
  22. ```
  23. """
  24. def get_model(self, cfg=None, weights=None, verbose=True):
  25. """
  26. Initialize and return an RT-DETR model for object detection tasks.
  27. Args:
  28. cfg (dict, optional): Model configuration. Defaults to None.
  29. weights (str, optional): Path to pre-trained model weights. Defaults to None.
  30. verbose (bool): Verbose logging if True. Defaults to True.
  31. Returns:
  32. (RTDETRDetectionModel): Initialized model.
  33. """
  34. model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
  35. if weights:
  36. model.load(weights)
  37. return model
  38. def build_dataset(self, img_path, mode="val", batch=None):
  39. """
  40. Build and return an RT-DETR dataset for training or validation.
  41. Args:
  42. img_path (str): Path to the folder containing images.
  43. mode (str): Dataset mode, either 'train' or 'val'.
  44. batch (int, optional): Batch size for rectangle training. Defaults to None.
  45. Returns:
  46. (RTDETRDataset): Dataset object for the specific mode.
  47. """
  48. return RTDETRDataset(
  49. img_path=img_path,
  50. is_train_on_platform=self.args.is_train_on_platform,
  51. imgsz=self.args.imgsz,
  52. batch_size=batch,
  53. augment=mode == "train",
  54. hyp=self.args,
  55. rect=False,
  56. cache=self.args.cache or None,
  57. prefix=colorstr(f"{mode}: "),
  58. data=self.data,
  59. )
  60. def get_validator(self):
  61. """
  62. Returns a DetectionValidator suitable for RT-DETR model validation.
  63. Returns:
  64. (RTDETRValidator): Validator object for model validation.
  65. """
  66. self.loss_names = "giou_loss", "cls_loss", "l1_loss"
  67. return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
  68. def preprocess_batch(self, batch):
  69. """
  70. Preprocess a batch of images. Scales and converts the images to float format.
  71. Args:
  72. batch (dict): Dictionary containing a batch of images, bboxes, and labels.
  73. Returns:
  74. (dict): Preprocessed batch.
  75. """
  76. batch = super().preprocess_batch(batch)
  77. bs = len(batch["img"])
  78. batch_idx = batch["batch_idx"]
  79. gt_bbox, gt_class = [], []
  80. for i in range(bs):
  81. gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
  82. gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
  83. return batch