# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou

from .ops import HungarianMatcher


class DETRLoss(nn.Module):
    """
    DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the
    DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally
    auxiliary losses.

    Attributes:
        nc (int): The number of classes.
        loss_gain (dict): Coefficients for different loss components.
        aux_loss (bool): Whether to compute auxiliary losses.
        use_fl (bool): Use FocalLoss or not.
        use_vfl (bool): Use VarifocalLoss or not.
        use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
        uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
        matcher (HungarianMatcher): Object to compute matching cost and indices.
        fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None.
        vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None.
        device (torch.device): Device on which tensors are stored.
    """

    def __init__(
        self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0
    ):
        """
        Initialize the DETR loss function.

        Args:
            nc (int): The number of classes.
            loss_gain (dict): Coefficients for each loss component.
            aux_loss (bool): If `aux_loss=True`, the loss at each decoder layer is used.
            use_fl (bool): Use FocalLoss or not.
            use_vfl (bool): Use VarifocalLoss or not.
            use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch.
            uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True.
        """
        super().__init__()

        if loss_gain is None:
            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss() if use_fl else None
        self.vfl = VarifocalLoss() if use_vfl else None

        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        self.device = None
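
    # Usage sketch (illustrative only, not part of the original module): constructing the loss with the
    # default gains shown in `__init__`, or with VarifocalLoss enabled for IoU-aware classification targets.
    #
    #   criterion = DETRLoss(nc=80)  # loss_gain defaults to {"class": 1, "bbox": 5, "giou": 2, ...}
    #   criterion = DETRLoss(nc=20, use_vfl=True)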

    def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""):
        """Computes the classification loss based on predictions, target values, and ground truth scores."""
        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = f"loss_class{postfix}"
        bs, nq = pred_scores.shape[:2]
        # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        if self.fl:
            if num_gts and self.vfl:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
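
    # Worked example (sketch, values made up): with nc=3 and a single query whose matched target class is 1
    # and whose predicted-vs-GT IoU is 0.7, the code above builds
    #   one_hot   -> [0, 1, 0]         (class index `nc` acts as background and is sliced off)
    #   gt_scores -> [0.0, 0.7, 0.0]   (IoU-weighted soft target used when VarifocalLoss is enabled)
    # Unmatched queries keep target class `nc`, so their one-hot row is all zeros.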

    def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""):
        """Calculates and returns the bounding box loss and GIoU loss for the predicted and ground truth bounding
        boxes.
        """
        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

        loss = {}
        if len(gt_bboxes) == 0:
            loss[name_bbox] = torch.tensor(0.0, device=self.device)
            loss[name_giou] = torch.tensor(0.0, device=self.device)
            return loss

        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
        return {k: v.squeeze() for k, v in loss.items()}
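
    # Worked example (sketch): with 4 matched pairs, loss_bbox is the element-wise L1 distance summed over all
    # matched box coordinates, divided by the 4 matched GT boxes and scaled by loss_gain["bbox"] (default 5);
    # loss_giou is the mean (1 - GIoU) over the same pairs scaled by loss_gain["giou"] (default 2).
    # Both reduce to scalars via the final `.squeeze()`.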

    # This function is for future RT-DETR Segment models
    # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
    #     # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
    #     name_mask = f'loss_mask{postfix}'
    #     name_dice = f'loss_dice{postfix}'
    #
    #     loss = {}
    #     if sum(len(a) for a in gt_mask) == 0:
    #         loss[name_mask] = torch.tensor(0., device=self.device)
    #         loss[name_dice] = torch.tensor(0., device=self.device)
    #         return loss
    #
    #     num_gts = len(gt_mask)
    #     src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
    #     src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
    #     # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
    #     loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
    #                                                                     torch.tensor([num_gts], dtype=torch.float32))
    #     loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
    #     return loss

    # This function is for future RT-DETR Segment models
    # @staticmethod
    # def _dice_loss(inputs, targets, num_gts):
    #     inputs = F.sigmoid(inputs).flatten(1)
    #     targets = targets.flatten(1)
    #     numerator = 2 * (inputs * targets).sum(1)
    #     denominator = inputs.sum(-1) + targets.sum(-1)
    #     loss = 1 - (numerator + 1) / (denominator + 1)
    #     return loss.sum() / num_gts

    def _get_loss_aux(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        match_indices=None,
        postfix="",
        masks=None,
        gt_mask=None,
    ):
        """Get auxiliary losses."""
        # NOTE: loss class, bbox, giou, mask, dice
        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        if match_indices is None and self.use_uni_match:
            match_indices = self.matcher(
                pred_bboxes[self.uni_match_ind],
                pred_scores[self.uni_match_ind],
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=masks[self.uni_match_ind] if masks is not None else None,
                gt_mask=gt_mask,
            )
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            loss_ = self._get_loss(
                aux_bboxes,
                aux_scores,
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=aux_masks,
                gt_mask=gt_mask,
                postfix=postfix,
                match_indices=match_indices,
            )
            loss[0] += loss_[f"loss_class{postfix}"]
            loss[1] += loss_[f"loss_bbox{postfix}"]
            loss[2] += loss_[f"loss_giou{postfix}"]
            # if masks is not None and gt_mask is not None:
            #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
            #     loss[3] += loss_[f'loss_mask{postfix}']
            #     loss[4] += loss_[f'loss_dice{postfix}']

        loss = {
            f"loss_class_aux{postfix}": loss[0],
            f"loss_bbox_aux{postfix}": loss[1],
            f"loss_giou_aux{postfix}": loss[2],
        }
        # if masks is not None and gt_mask is not None:
        #     loss[f'loss_mask_aux{postfix}'] = loss[3]
        #     loss[f'loss_dice_aux{postfix}'] = loss[4]
        return loss

    @staticmethod
    def _get_index(match_indices):
        """Returns batch indices, source indices, and destination indices from provided match indices."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx
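
    # Worked example (sketch, values made up): for a batch of 2 images where image 0 has 2 GT boxes and
    # image 1 has 1, a possible matcher output is
    #   match_indices = [(tensor([5, 9]), tensor([0, 1])), (tensor([2]), tensor([2]))]
    # and this method returns
    #   (batch_idx, src_idx) = (tensor([0, 0, 1]), tensor([5, 9, 2]))  # indexes predictions of shape [b, query, ...]
    #   dst_idx = tensor([0, 1, 2])  # indexes the batch-concatenated ground-truth tensors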

    def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices):
        """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices."""
        pred_assigned = torch.cat(
            [
                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (i, _) in zip(pred_bboxes, match_indices)
            ]
        )
        gt_assigned = torch.cat(
            [
                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (_, j) in zip(gt_bboxes, match_indices)
            ]
        )
        return pred_assigned, gt_assigned

    def _get_loss(
        self,
        pred_bboxes,
        pred_scores,
        gt_bboxes,
        gt_cls,
        gt_groups,
        masks=None,
        gt_mask=None,
        postfix="",
        match_indices=None,
    ):
        """Get losses."""
        if match_indices is None:
            match_indices = self.matcher(
                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
            )

        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        bs, nq = pred_scores.shape[:2]
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        loss = {}
        loss.update(self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix))
        loss.update(self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix))
        # if masks is not None and gt_mask is not None:
        #     loss.update(self._get_loss_mask(masks, gt_mask, match_indices, postfix))
        return loss

    def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs):
        """
        Calculate the loss for predicted bounding boxes and scores.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes of shape [l, b, query, 4].
            pred_scores (torch.Tensor): Predicted class scores of shape [l, b, query, num_classes].
            batch (dict): A dict that includes:
                cls (torch.Tensor): Ground truth classes of shape [num_gts].
                bboxes (torch.Tensor): Ground truth boxes of shape [num_gts, 4].
                gt_groups (List[int]): A list of length batch_size with the number of ground truths per image.
            postfix (str): Postfix of the loss name.
            **kwargs (Any): May include 'match_indices' to reuse precomputed matches.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get("match_indices", None)
        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

        total_loss = self._get_loss(
            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
        )

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(
                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
                )
            )

        return total_loss
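
# Usage sketch for `DETRLoss` (illustrative only; variable names are placeholders, shapes follow the docstring above):
#
#   criterion = DETRLoss(nc=80)
#   # pred_bboxes: [num_decoder_layers, batch, num_queries, 4], pred_scores: [..., 80]
#   batch = {"cls": gt_cls, "bboxes": gt_bboxes, "gt_groups": [3, 1]}  # 3 GTs in image 0, 1 in image 1
#   losses = criterion(pred_bboxes, pred_scores, batch)
#   total = sum(losses.values())  # keys: loss_class, loss_bbox, loss_giou (+ *_aux when aux_loss=True)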


class RTDETRDetectionLoss(DETRLoss):
    """
    Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well
    as an additional denoising training loss when provided with denoising metadata.
    """

    def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None):
        """
        Forward pass to compute the detection loss.

        Args:
            preds (tuple): Predicted bounding boxes and scores.
            batch (dict): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None.
            dn_scores (torch.Tensor, optional): Denoising scores. Default is None.
            dn_meta (dict, optional): Metadata for denoising. Default is None.

        Returns:
            (dict): Dictionary containing the total loss and, if applicable, the denoising loss.
        """
        pred_bboxes, pred_scores = preds
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        # Check for denoising metadata to compute denoising training loss
        if dn_meta is not None:
            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
            assert len(batch["gt_groups"]) == len(dn_pos_idx)

            # Get the match indices for denoising
            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])

            # Compute the denoising training loss
            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
            total_loss.update(dn_loss)
        else:
            # If no denoising metadata is provided, set denoising loss to zero
            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})

        return total_loss
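
    # Usage sketch (illustrative only; variable names are placeholders, dn_meta layout follows the keys read above):
    # during denoising training the model also emits dn_bboxes/dn_scores, and dn_meta carries the positive
    # indices and group count used to build fixed match indices.
    #
    #   dn_meta = {"dn_pos_idx": dn_pos_idx, "dn_num_group": dn_num_group}
    #   losses = criterion((pred_bboxes, pred_scores), batch, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta)
    #   # adds *_dn keys (e.g. loss_class_dn) alongside the standard losses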

    @staticmethod
    def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups):
        """
        Get the match indices for denoising.

        Args:
            dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising.
            dn_num_group (int): Number of denoising groups.
            gt_groups (List[int]): List of integers representing the number of ground truths for each image.

        Returns:
            (List[tuple]): List of tuples containing matched indices for denoising.
        """
        dn_match_indices = []
        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        for i, num_gt in enumerate(gt_groups):
            if num_gt > 0:
                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                gt_idx = gt_idx.repeat(dn_num_group)
                assert len(dn_pos_idx[i]) == len(gt_idx), (
                    f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
                )
                dn_match_indices.append((dn_pos_idx[i], gt_idx))
            else:
                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
        return dn_match_indices
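
    # Worked example (sketch): with gt_groups=[2, 1] and dn_num_group=2, idx_groups is tensor([0, 2]), so
    # image 0 yields gt_idx = tensor([0, 1, 0, 1]) and image 1 yields gt_idx = tensor([2, 2]); each is paired
    # with the corresponding dn_pos_idx entry, which must have the same length (checked by the assert above).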