import torch
import torch.nn as nn
import torch.nn.functional as F


def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """Mark the anchor centers that lie strictly inside each ground-truth box.

    Args:
        xy_centers (Tensor): anchor centers, shape (h*w, 2).
        gt_bboxes (Tensor): boxes as (x1, y1, x2, y2), shape (b, n_boxes, 4).
        eps (float): minimum distance an anchor center must keep from every
            box edge to count as inside.

    Returns:
        Tensor: shape (b, n_boxes, h*w), same dtype as the inputs, holding 1
        where the center is inside the box and 0 elsewhere.
    """
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    # Split every box into its left-top and right-bottom corners.
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
    # Signed distances from the center to the four edges; all four are
    # positive exactly when the center is inside the box.
    bbox_deltas = torch.cat(
        (xy_centers[None] - lt, rb - xy_centers[None]),
        dim=2).view(bs, n_boxes, n_anchors, -1)
    return bbox_deltas.amin(3).gt_(eps)


def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """Resolve anchors assigned to several ground truths.

    If an anchor is assigned to multiple gts, only the gt with the highest
    IoU keeps it.

    Args:
        mask_pos (Tensor): shape (b, n_max_boxes, h*w).
        overlaps (Tensor): shape (b, n_max_boxes, h*w).
        n_max_boxes (int): size of the gt dimension of ``mask_pos``.

    Returns:
        target_gt_idx (Tensor): shape (b, h*w), index of the assigned gt.
        fg_mask (Tensor): shape (b, h*w), number of gts assigned per anchor
            (0 or 1 after conflict resolution).
        mask_pos (Tensor): shape (b, n_max_boxes, h*w).
    """
    # (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:
        # At least one anchor is assigned to multiple gt boxes.
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(
            -1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
        is_max_overlaps = torch.zeros(
            mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
        # Only the conflicting anchors are overwritten with the arg-max gt.
        mask_pos = torch.where(
            mask_multi_gts, is_max_overlaps, mask_pos).float()
        fg_mask = mask_pos.sum(-2)
    # Index of the gt each anchor serves.
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos


def generate_anchors(feats, fpn_strides, grid_cell_size=5.0,
                     grid_cell_offset=0.5, device='cpu', is_eval=False,
                     mode='af'):
    """Generate anchors from feature maps.

    Args:
        feats (Sequence[Tensor]): one (b, c, h, w) feature map per level; only
            the spatial size (and, in training mode, the dtype of
            ``feats[0]``) is read.
        fpn_strides (Sequence[number]): stride of each level.
        grid_cell_size (float): anchor side length in units of the stride
            (training mode only).
        grid_cell_offset (float): offset of the anchor point inside a cell.
        device: device the tensors are created on / moved to.
        is_eval (bool): selects the return signature, see below.
        mode (str): ``'af'`` (anchor-free, one anchor per cell) or ``'ab'``
            (anchor-based, every level repeated three times).

    Returns:
        If ``is_eval``: ``(anchor_points, stride_tensor)`` where the points
        are expressed in grid units.
        Otherwise: ``(anchors, anchor_points, num_anchors_list,
        stride_tensor)`` where anchors / points are expressed in stride-scaled
        units.

    Raises:
        ValueError: if ``mode`` is neither ``'af'`` nor ``'ab'``.
    """
    anchors = []
    anchor_points = []
    stride_tensor = []
    num_anchors_list = []
    assert feats is not None
    if mode not in ('af', 'ab'):
        # An unknown mode used to append nothing and fail later with an
        # unrelated IndexError / torch.cat error.
        raise ValueError(f"Unsupported mode {mode!r}, expected 'af' or 'ab'")

    if is_eval:
        for i, stride in enumerate(fpn_strides):
            _, _, h, w = feats[i].shape
            shift_x = torch.arange(end=w, device=device) + grid_cell_offset
            shift_y = torch.arange(end=h, device=device) + grid_cell_offset
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            anchor_point = torch.stack(
                [shift_x, shift_y], dim=-1).to(torch.float)
            level_strides = torch.full(
                (h * w, 1), stride, dtype=torch.float, device=device)
            if mode == 'af':  # anchor-free
                anchor_points.append(anchor_point.reshape([-1, 2]))
                stride_tensor.append(level_strides)
            else:  # anchor-based
                anchor_points.append(
                    anchor_point.reshape([-1, 2]).repeat(3, 1))
                stride_tensor.append(level_strides.repeat(3, 1))
        anchor_points = torch.cat(anchor_points)
        stride_tensor = torch.cat(stride_tensor)
        return anchor_points, stride_tensor

    for i, stride in enumerate(fpn_strides):
        _, _, h, w = feats[i].shape
        cell_half_size = grid_cell_size * stride * 0.5
        shift_x = (torch.arange(end=w, device=device)
                   + grid_cell_offset) * stride
        shift_y = (torch.arange(end=h, device=device)
                   + grid_cell_offset) * stride
        shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
        anchor = torch.stack(
            [
                shift_x - cell_half_size, shift_y - cell_half_size,
                shift_x + cell_half_size, shift_y + cell_half_size
            ],
            dim=-1).clone().to(feats[0].dtype)
        anchor_point = torch.stack(
            [shift_x, shift_y], dim=-1).clone().to(feats[0].dtype)

        if mode == 'af':  # anchor-free
            anchors.append(anchor.reshape([-1, 4]))
            anchor_points.append(anchor_point.reshape([-1, 2]))
        else:  # anchor-based
            anchors.append(anchor.reshape([-1, 4]).repeat(3, 1))
            anchor_points.append(anchor_point.reshape([-1, 2]).repeat(3, 1))
        num_anchors_list.append(len(anchors[-1]))
        stride_tensor.append(
            torch.full(
                [num_anchors_list[-1], 1], stride, dtype=feats[0].dtype))
    anchors = torch.cat(anchors)
    anchor_points = torch.cat(anchor_points).to(device)
    stride_tensor = torch.cat(stride_tensor).to(device)
    return anchors, anchor_points, num_anchors_list, stride_tensor


def fp16_clamp(x, min=None, max=None):
    """Clamp ``x``, routing CPU float16 tensors through float32.

    Args:
        x (Tensor): tensor to clamp.
        min, max: bounds forwarded to ``Tensor.clamp``.

    Returns:
        Tensor: clamped tensor with the dtype of ``x``.
    """
    if not x.is_cuda and x.dtype == torch.float16:
        # clamp for cpu float16, tensor fp16 has no clamp implementation
        return x.float().clamp(min, max).half()

    return x.clamp(min, max)


def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889
    Note:
        Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou',
        there are some new generated variables when calculating IOU
        using bbox_overlaps function:

        1) is_aligned is False
            area1: M x 1
            area2: N x 1
            lt: M x N x 2
            rb: M x N x 2
            wh: M x N x 2
            overlap: M x N x 1
            union: M x N x 1
            ious: M x N x 1

            Total memory:
                S = (9 x N x M + N + M) * 4 Byte,

            When using FP16, we can reduce:
                R = (9 x N x M + N + M) * 4 / 2 Byte
                R larger than (N + M) * 4 * 2 is always true when N and M >= 1.
                Obviously, N + M <= N * M < 3 * N * M, when N >= 2 and M >= 2,
                           N + 1 < 3 * N, when N or M is 1.

            Given M = 40 (ground truth), N = 400000 (three anchor boxes
            in per grid, FPN, R-CNNs),
                R = 275 MB (one times)

            A special case (dense detection), M = 512 (ground truth),
                R = 3516 MB = 3.43 GB

            When the batch size is B, reduce:
                B x R

            Therefore, CUDA memory runs out frequently.

            Experiments on GeForce RTX 2080Ti (11019 MiB):

            |   dtype   |   M   |   N   |   Use    |   Real   |   Ideal   |
            |:----:|:----:|:----:|:----:|:----:|:----:|
            |   FP32   |   512 | 400000 | 8020 MiB |   --   |   --   |
            |   FP16   |   512 | 400000 |   4504 MiB | 3516 MiB | 3516 MiB |
            |   FP32   |   40 | 400000 |   1540 MiB |   --   |   --   |
            |   FP16   |   40 | 400000 |   1264 MiB |   276MiB   | 275 MiB |

        2) is_aligned is True
            area1: N x 1
            area2: N x 1
            lt: N x 2
            rb: N x 2
            wh: N x 2
            overlap: N x 1
            union: N x 1
            ious: N x 1

            Total memory:
                S = 11 x N * 4 Byte

            When using FP16, we can reduce:
                R = 11 x N * 4 / 2 Byte

        So do the 'giou' (larger than 'iou').

        Time-wise, FP16 is generally faster than FP32.

        When gpu_assign_thr is not -1, it takes more time on cpu
        but does not reduce memory.
        There, we can reduce half the memory and keep the speed.

    If ``is_aligned`` is ``False``, then calculate the overlaps between each
    bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
    pair of bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in (x1, y1, x2, y2) format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in (x1, y1, x2, y2) format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): Lower bound applied to the denominator for
            numerical stability. Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m).

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = tuple(bboxes1.shape[:-2])

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        # The result has zero elements; new_empty states the intent (a size,
        # not data) unambiguously, unlike the legacy ``Tensor.new``.
        if is_aligned:
            return bboxes1.new_empty(batch_shape + (rows, ))
        return bboxes1.new_empty(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
        rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
            enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
    else:
        lt = torch.max(bboxes1[..., :, None, :2],
                       bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
        rb = torch.min(bboxes1[..., :, None, 2:],
                       bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        else:
            union = area1[..., None]
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :2],
                                    bboxes2[..., None, :, :2])
            enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
                                    bboxes2[..., None, :, 2:])

    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious
    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious


def cast_tensor_type(x, scale=1., dtype=None):
    """Convert ``x`` to half precision when ``dtype == 'fp16'``.

    Args:
        x (Tensor): tensor to convert.
        scale (float): divisor applied before the cast to prevent overflow.
        dtype (str | None): only ``'fp16'`` triggers a conversion; any other
            value returns ``x`` unchanged.

    Returns:
        Tensor: ``(x / scale).half()`` for ``'fp16'``, otherwise ``x``.
    """
    if dtype == 'fp16':
        # scale is for preventing overflows
        x = (x / scale).half()
    return x


def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False,
                     scale=1., dtype=None):
    """2D overlaps (e.g. IoUs, GIoUs) calculator.

    Args:
        bboxes1 (Tensor): shape (m, 4) in (x1, y1, x2, y2) format, or shape
            (m, 5) whose fifth column is dropped before the computation.
        bboxes2 (Tensor): shape (n, 4) in (x1, y1, x2, y2) format, shape
            (n, 5) whose fifth column is dropped, or empty.
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection
            over foreground), or "giou" (generalized intersection over
            union).
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        scale (float): divisor used when casting to fp16; ignored unless
            ``dtype == 'fp16'``.
        dtype (str | None): ``'fp16'`` runs the computation in half
            precision.

    Returns:
        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,).
    """
    assert bboxes1.size(-1) in [0, 4, 5]
    assert bboxes2.size(-1) in [0, 4, 5]
    if bboxes2.size(-1) == 5:
        bboxes2 = bboxes2[..., :4]
    if bboxes1.size(-1) == 5:
        bboxes1 = bboxes1[..., :4]

    if dtype == 'fp16':
        # change tensor type to save cpu and cuda memory and keep speed
        bboxes1 = cast_tensor_type(bboxes1, scale, dtype)
        bboxes2 = cast_tensor_type(bboxes2, scale, dtype)
        overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
        if not overlaps.is_cuda and overlaps.dtype == torch.float16:
            # resume cpu float32
            overlaps = overlaps.float()
        return overlaps

    return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)


def dist_calculator(gt_bboxes, anchor_bboxes):
    """Compute the center distance between every gt box and every anchor.

    Args:
        gt_bboxes (Tensor): shape (bs*n_max_boxes, 4).
        anchor_bboxes (Tensor): shape (num_total_anchors, 4).

    Returns:
        distances (Tensor): shape (bs*n_max_boxes, num_total_anchors),
            Euclidean distance between box centers.
        ac_points (Tensor): shape (num_total_anchors, 2), anchor centers.
    """
    gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
    gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
    gt_points = torch.stack([gt_cx, gt_cy], dim=1)
    ac_cx = (anchor_bboxes[:, 0] + anchor_bboxes[:, 2]) / 2.0
    ac_cy = (anchor_bboxes[:, 1] + anchor_bboxes[:, 3]) / 2.0
    ac_points = torch.stack([ac_cx, ac_cy], dim=1)

    distances = (gt_points[:, None, :]
                 - ac_points[None, :, :]).pow(2).sum(-1).sqrt()

    return distances, ac_points


def iou_calculator(box1, box2, eps=1e-9):
    """Calculate the pairwise IoU for a batch of boxes.

    Args:
        box1 (Tensor): shape (bs, n_max_boxes, 4).
        box2 (Tensor): shape (bs, num_total_anchors, 4).
        eps (float): added to the union to avoid division by zero.

    Returns:
        Tensor: shape (bs, n_max_boxes, num_total_anchors).
    """
    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
    x1y1 = torch.maximum(px1y1, gx1y1)
    x2y2 = torch.minimum(px2y2, gx2y2)
    overlap = (x2y2 - x1y1).clip(0).prod(-1)
    area1 = (px2y2 - px1y1).clip(0).prod(-1)
    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
    union = area1 + area2 - overlap + eps

    return overlap / union


class ATSSAssigner(nn.Module):
    """Adaptive Training Sample Selection Assigner."""

    def __init__(self, topk=9, num_classes=80):
        """Store the assignment hyper-parameters.

        Args:
            topk (int): number of closest anchors picked per gt and per level.
            num_classes (int): number of classes; also used as the background
                label index.
        """
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes

    @torch.no_grad()
    def forward(self, anc_bboxes, n_level_bboxes, gt_labels, gt_bboxes,
                mask_gt, pd_bboxes):
        r"""Assign ground truths to anchors.

        This code is based on
        https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py

        Args:
            anc_bboxes (Tensor): shape (num_total_anchors, 4).
            n_level_bboxes (List[int]): number of anchors on each level; must
                sum to num_total_anchors.
            gt_labels (Tensor): shape (bs, n_max_boxes, 1).
            gt_bboxes (Tensor): shape (bs, n_max_boxes, 4).
            mask_gt (Tensor): shape (bs, n_max_boxes, 1), non-zero for real
                (non-padded) gts.
            pd_bboxes (Tensor | None): predicted boxes, shape
                (bs, num_total_anchors, 4). When given, target scores are
                scaled by the IoU between prediction and assigned gt.

        Returns:
            target_labels (Tensor): shape (bs, num_total_anchors).
            target_bboxes (Tensor): shape (bs, num_total_anchors, 4).
            target_scores (Tensor): shape
                (bs, num_total_anchors, num_classes).
            fg_mask (Tensor): bool, shape (bs, num_total_anchors).
            target_gt_idx (Tensor): long, shape (bs, num_total_anchors).
        """
        self.n_anchors = anc_bboxes.size(0)
        self.bs = gt_bboxes.size(0)
        self.n_max_boxes = gt_bboxes.size(1)

        if self.n_max_boxes == 0:
            # No gt in the batch: everything is background. fg_mask and
            # target_gt_idx use the same dtypes as the regular path below.
            device = gt_bboxes.device
            return (
                torch.full([self.bs, self.n_anchors], self.bg_idx,
                           device=device),
                torch.zeros([self.bs, self.n_anchors, 4], device=device),
                torch.zeros([self.bs, self.n_anchors, self.num_classes],
                            device=device),
                torch.zeros([self.bs, self.n_anchors], dtype=torch.bool,
                            device=device),
                torch.zeros([self.bs, self.n_anchors], dtype=torch.long,
                            device=device),
            )

        overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        overlaps = overlaps.reshape([self.bs, -1, self.n_anchors])

        distances, ac_points = dist_calculator(
            gt_bboxes.reshape([-1, 4]), anc_bboxes)
        distances = distances.reshape([self.bs, -1, self.n_anchors])

        is_in_candidate, candidate_idxs = self.select_topk_candidates(
            distances, n_level_bboxes, mask_gt)

        overlaps_thr_per_gt, iou_candidates = self.thres_calculator(
            is_in_candidate, candidate_idxs, overlaps)

        # Candidates whose IoU exceeds the per-gt threshold are positive; the
        # (bs, n_max_boxes, 1) threshold broadcasts over the anchor axis.
        is_pos = torch.where(
            iou_candidates > overlaps_thr_per_gt,
            is_in_candidate, torch.zeros_like(is_in_candidate))

        is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes)
        mask_pos = is_pos * is_in_gts * mask_gt

        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
            mask_pos, overlaps, self.n_max_boxes)

        # assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(
            gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # soft label with iou
        if pd_bboxes is not None:
            ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos
            ious = ious.max(dim=-2)[0].unsqueeze(-1)
            target_scores *= ious

        return (target_labels, target_bboxes, target_scores, fg_mask.bool(),
                target_gt_idx)

    def select_topk_candidates(self, distances, n_level_bboxes, mask_gt):
        """Pick, on every level, the ``topk`` anchors closest to each gt.

        Args:
            distances (Tensor): shape (bs, n_max_boxes, num_total_anchors).
            n_level_bboxes (List[int]): number of anchors on each level.
            mask_gt (Tensor): shape (bs, n_max_boxes, 1), non-zero for real
                gts.

        Returns:
            is_in_candidate (Tensor): shape
                (bs, n_max_boxes, num_total_anchors), 1 for selected anchors.
            candidate_idxs (Tensor): shape (bs, n_max_boxes, K) with global
                anchor indices, where K is the sum over levels of
                ``min(topk, anchors on that level)``.
        """
        # Keep the mask as (bs, n_max_boxes, 1) so it broadcasts against any
        # per-level k; a level may hold fewer than ``topk`` anchors.
        mask_gt = mask_gt.bool()
        level_distances = torch.split(distances, n_level_bboxes, dim=-1)
        is_in_candidate_list = []
        candidate_idxs = []
        start_idx = 0
        for per_level_distances, per_level_boxes in zip(level_distances,
                                                        n_level_bboxes):
            end_idx = start_idx + per_level_boxes
            selected_k = min(self.topk, per_level_boxes)
            _, per_level_topk_idxs = per_level_distances.topk(
                selected_k, dim=-1, largest=False)
            candidate_idxs.append(per_level_topk_idxs + start_idx)
            # Padded gts get every index collapsed onto 0, so the duplicate
            # filter below wipes them out (whenever selected_k > 1).
            per_level_topk_idxs = torch.where(
                mask_gt, per_level_topk_idxs,
                torch.zeros_like(per_level_topk_idxs))
            is_in_candidate = F.one_hot(
                per_level_topk_idxs, per_level_boxes).sum(dim=-2)
            is_in_candidate = torch.where(
                is_in_candidate > 1,
                torch.zeros_like(is_in_candidate), is_in_candidate)
            is_in_candidate_list.append(is_in_candidate.to(distances.dtype))
            start_idx = end_idx

        is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
        candidate_idxs = torch.cat(candidate_idxs, dim=-1)

        return is_in_candidate_list, candidate_idxs

    def thres_calculator(self, is_in_candidate, candidate_idxs, overlaps):
        """Compute the per-gt IoU threshold (mean + std over its candidates).

        Args:
            is_in_candidate (Tensor): shape
                (bs, n_max_boxes, num_total_anchors).
            candidate_idxs (Tensor): shape (bs, n_max_boxes, K), global anchor
                indices of the candidates.
            overlaps (Tensor): shape (bs, n_max_boxes, num_total_anchors).

        Returns:
            overlaps_thr_per_gt (Tensor): shape (bs, n_max_boxes, 1).
            candidate_overlaps (Tensor): ``overlaps`` with non-candidates
                zeroed, shape (bs, n_max_boxes, num_total_anchors).
        """
        n_bs_max_boxes = self.bs * self.n_max_boxes
        _candidate_overlaps = torch.where(
            is_in_candidate > 0, overlaps, torch.zeros_like(overlaps))
        candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1])
        # Row offsets that turn per-gt anchor indices into indices of the
        # flattened (bs*n_max_boxes*num_total_anchors) overlap tensor.
        assist_idxs = self.n_anchors * torch.arange(
            n_bs_max_boxes, device=candidate_idxs.device)
        assist_idxs = assist_idxs[:, None]
        flatten_idxs = candidate_idxs + assist_idxs
        candidate_overlaps = _candidate_overlaps.reshape(-1)[flatten_idxs]
        candidate_overlaps = candidate_overlaps.reshape(
            [self.bs, self.n_max_boxes, -1])

        overlaps_mean_per_gt = candidate_overlaps.mean(dim=-1, keepdim=True)
        overlaps_std_per_gt = candidate_overlaps.std(dim=-1, keepdim=True)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt

        return overlaps_thr_per_gt, _candidate_overlaps

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """Gather labels, boxes and one-hot scores of the assigned gts.

        Args:
            gt_labels (Tensor): shape (bs, n_max_boxes, 1).
            gt_bboxes (Tensor): shape (bs, n_max_boxes, 4).
            target_gt_idx (Tensor): shape (bs, num_total_anchors).
            fg_mask (Tensor): shape (bs, num_total_anchors), > 0 for
                foreground anchors.

        Returns:
            target_labels (Tensor): shape (bs, num_total_anchors); background
                anchors carry ``self.bg_idx``.
            target_bboxes (Tensor): shape (bs, num_total_anchors, 4).
            target_scores (Tensor): shape
                (bs, num_total_anchors, num_classes); all-zero rows for
                background anchors.
        """
        # assigned target labels
        batch_idx = torch.arange(
            self.bs, dtype=gt_labels.dtype, device=gt_labels.device)
        batch_idx = batch_idx[..., None]
        # Offset each image's gt index into the flattened (bs*n_max_boxes)
        # gt arrays.
        target_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long()
        target_labels = gt_labels.flatten()[target_gt_idx.flatten()]
        target_labels = target_labels.reshape([self.bs, self.n_anchors])
        target_labels = torch.where(
            fg_mask > 0, target_labels,
            torch.full_like(target_labels, self.bg_idx))

        # assigned target boxes
        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()]
        target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4])

        # assigned target scores: the extra background class is one-hot
        # encoded and then sliced away.
        target_scores = F.one_hot(
            target_labels.long(), self.num_classes + 1).float()
        target_scores = target_scores[:, :, :self.num_classes]

        return target_labels, target_bboxes, target_scores