# Ultralytics YOLO 🚀, AGPL-3.0 license

from copy import copy

import torch

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr

from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

    Extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
    This model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
    inference speed.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Example:
        ```python
        from ultralytics.models.rtdetr.train import RTDETRTrainer

        args = dict(model='rtdetr-l.yaml', data='coco8.yaml', imgsz=640, epochs=3)
        trainer = RTDETRTrainer(overrides=args)
        trainer.train()
        ```
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration. Defaults to None.
            weights (str, optional): Path to pre-trained model weights. Defaults to None.
            verbose (bool): Verbose logging if True. Defaults to True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        # Only log the model summary on the main process (RANK == -1 means non-DDP / single process).
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training. Defaults to None.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            is_train_on_platform=self.args.is_train_on_platform,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",  # augment only while training
            hyp=self.args,
            rect=False,  # rectangular batching is disabled for RT-DETR
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def get_validator(self):
        """
        Returns a DetectionValidator suitable for RT-DETR model validation.

        Returns:
            (RTDETRValidator): Validator object for model validation.
        """
        # Loss components reported for RT-DETR: GIoU, classification, and L1 box regression.
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))

    def preprocess_batch(self, batch):
        """
        Preprocess a batch of images. Scales and converts the images to float format.

        Args:
            batch (dict): Dictionary containing a batch of images, bboxes, and labels.

        Returns:
            (dict): Preprocessed batch.
        """
        # Fix: the previous version also gathered per-image `gt_bbox`/`gt_class` lists here
        # (with device transfers), but they were never used or returned — dead work on every
        # training step, removed. The returned batch is identical to before.
        return super().preprocess_batch(batch)