123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """Flag anchor centres that lie strictly inside each ground-truth box.

    Args:
        xy_centers (Tensor): anchor centres, shape (h*w, 2) as (x, y).
        gt_bboxes (Tensor): boxes, shape (b, n_boxes, 4) as (x1, y1, x2, y2).
        eps (float): a centre counts as inside only if its distance to every
            box edge is greater than this value.
    Returns:
        (Tensor): shape (b, n_boxes, h*w), 1.0 where the centre is inside the
        box and 0.0 otherwise (floating dtype of the edge distances).
    """
    num_anchors = xy_centers.shape[0]
    batch, num_boxes, _ = gt_bboxes.shape
    top_left, bottom_right = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
    centers = xy_centers[None]
    # Signed distances from every centre to the left/top and right/bottom edges.
    to_left_top = centers - top_left
    to_right_bottom = bottom_right - centers
    deltas = torch.cat((to_left_top, to_right_bottom), dim=2)
    deltas = deltas.view(batch, num_boxes, num_anchors, -1)
    # Inside <=> the smallest of the four distances exceeds eps.
    min_delta = deltas.amin(3)
    return min_delta.gt_(eps)
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """Resolve anchors claimed by several GTs in favour of the highest-IoU GT.

    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w), non-zero where an
            anchor is a positive for a GT.
        overlaps (Tensor): shape(b, n_max_boxes, h*w), GT-anchor IoU.
        n_max_boxes (int): size of the GT dimension of the inputs.
    Returns:
        target_gt_idx (Tensor): shape(b, h*w), index of the GT each anchor serves
            (0 for anchors with no positive GT, since argmax of all zeros is 0).
        fg_mask (Tensor): shape(b, h*w), number of GTs assigned per anchor.
        mask_pos (Tensor): shape(b, n_max_boxes, h*w), de-duplicated mask
            (converted to float only when a conflict had to be resolved).
    """
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:
        # Anchors that were claimed by more than one GT.
        conflicted = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)
        # One-hot mask selecting, per anchor, the GT with the largest IoU.
        best_gt = overlaps.argmax(1).unsqueeze(1)
        best_one_hot = mask_pos.new_zeros(mask_pos.shape)
        best_one_hot.scatter_(1, best_gt, 1)
        mask_pos = torch.where(conflicted, best_one_hot, mask_pos).float()
        fg_mask = mask_pos.sum(-2)
    target_gt_idx = mask_pos.argmax(-2)
    return target_gt_idx, fg_mask, mask_pos
def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5, device='cpu', is_eval=False, mode='af'):
    '''Generate anchors from features.

    Args:
        feats (Sequence[Tensor]): per-level feature maps; only their spatial
            size ``shape[2:]`` is read, plus ``feats[0].dtype`` in training mode.
        fpn_strides (Sequence[int | float]): stride of each level; level ``i``
            uses ``feats[i]``.
        grid_cell_size (float): anchor box side length in units of the level
            stride (training mode only).
        grid_cell_offset (float): offset added to integer grid coordinates.
        device: device on which the returned tensors are created.
        is_eval (bool): if True, return only anchor points and strides.
        mode (str): 'af' (one anchor per grid location) or 'ab' (every grid
            location repeated 3 times).
    Returns:
        If ``is_eval``: ``(anchor_points, stride_tensor)`` with shapes (N, 2) and
        (N, 1), float32, points in grid units (not multiplied by the stride).
        Otherwise: ``(anchors, anchor_points, num_anchors_list, stride_tensor)``
        with shapes (N, 4), (N, 2), list of per-level counts, (N, 1), in
        ``feats[0].dtype`` and in pixel units (points multiplied by the stride).
    Raises:
        ValueError: if ``mode`` is neither 'af' nor 'ab'.
    '''
    assert feats is not None
    if mode not in ('af', 'ab'):
        raise ValueError(f"Unsupported anchor mode {mode!r}, expected 'af' or 'ab'")
    # Anchor-based mode stacks 3 copies of every grid location; anchor-free keeps one.
    n_repeat = 1 if mode == 'af' else 3

    if is_eval:
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(fpn_strides):
            _, _, h, w = feats[i].shape
            shift_x = torch.arange(end=w, device=device) + grid_cell_offset
            shift_y = torch.arange(end=h, device=device) + grid_cell_offset
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(torch.float)
            anchor_points.append(anchor_point.reshape([-1, 2]).repeat(n_repeat, 1))
            stride_tensor.append(
                torch.full((h * w, 1), stride, dtype=torch.float, device=device).repeat(n_repeat, 1))
        return torch.cat(anchor_points), torch.cat(stride_tensor)

    anchors = []
    anchor_points = []
    stride_tensor = []
    num_anchors_list = []
    for i, stride in enumerate(fpn_strides):
        _, _, h, w = feats[i].shape
        dtype = feats[0].dtype
        cell_half_size = grid_cell_size * stride * 0.5
        shift_x = (torch.arange(end=w, device=device) + grid_cell_offset) * stride
        shift_y = (torch.arange(end=h, device=device) + grid_cell_offset) * stride
        shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
        # Square box of side grid_cell_size * stride centred on each point.
        anchor = torch.stack(
            [
                shift_x - cell_half_size, shift_y - cell_half_size,
                shift_x + cell_half_size, shift_y + cell_half_size
            ],
            dim=-1).to(dtype)
        anchor_point = torch.stack([shift_x, shift_y], dim=-1).to(dtype)
        anchors.append(anchor.reshape([-1, 4]).repeat(n_repeat, 1))
        anchor_points.append(anchor_point.reshape([-1, 2]).repeat(n_repeat, 1))
        num_anchors_list.append(len(anchors[-1]))
        stride_tensor.append(
            torch.full([num_anchors_list[-1], 1], stride, dtype=dtype, device=device))
    anchors = torch.cat(anchors)
    anchor_points = torch.cat(anchor_points).to(device)
    stride_tensor = torch.cat(stride_tensor).to(device)
    return anchors, anchor_points, num_anchors_list, stride_tensor
def fp16_clamp(x, min=None, max=None):
    """Clamp ``x`` into ``[min, max]``; either bound may be None.

    Non-CUDA float16 tensors are clamped through a float32 round-trip
    (tensor fp16 has no CPU clamp implementation) and cast back to float16.
    """
    cpu_half = x.dtype == torch.float16 and not x.is_cuda
    source = x.float() if cpu_half else x
    clamped = source.clamp(min, max)
    return clamped.half() if cpu_half else clamped
def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889

    If ``is_aligned`` is ``False``, the overlap is computed between every box
    of ``bboxes1`` and every box of ``bboxes2``; otherwise it is computed
    between each aligned pair ``(bboxes1[..., i, :], bboxes2[..., i, :])``.

    Note (memory):
        With M boxes in ``bboxes1`` and N in ``bboxes2`` and ``is_aligned``
        False, the intermediates (lt, rb, wh, overlap, union, ious) are all
        M x N sized, so FP32 needs about S = (9 x N x M + N + M) * 4 bytes and
        FP16 saves R = S / 2 bytes. With M = 40 ground truths and N = 400000
        anchors, R is about 275 MB per image; with M = 512 (dense detection)
        it is about 3516 MB = 3.43 GB, multiplied by the batch size B, which is
        why CUDA memory runs out frequently. Measured on a GeForce RTX 2080Ti
        (11019 MiB):

        | dtype | M   | N      | Use      | Real     | Ideal    |
        |:-----:|:---:|:------:|:--------:|:--------:|:--------:|
        | FP32  | 512 | 400000 | 8020 MiB | --       | --       |
        | FP16  | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
        | FP32  | 40  | 400000 | 1540 MiB | --       | --       |
        | FP16  | 40  | 400000 | 1264 MiB | 276MiB   | 275 MiB  |

        With ``is_aligned`` True every intermediate is N sized, about
        S = 11 x N * 4 bytes in FP32, halved by FP16. 'giou' needs more memory
        than 'iou'. FP16 is generally also faster than FP32.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
            B indicates the batch dims, in shape (B1, B2, ..., Bn), and must
            match between the two inputs. If ``is_aligned`` is ``True``, then
            m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground, i.e. over the area of ``bboxes1``) or "giou"
            (generalized intersection over union). Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): lower bound applied to the union (and to the
            enclosing area for giou) for numerical stability. Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m).

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
    # Leading batch dims (B1, B2, ..., Bn) must agree.
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    n_rows, n_cols = bboxes1.size(-2), bboxes2.size(-2)
    if is_aligned:
        assert n_rows == n_cols

    if n_rows * n_cols == 0:
        out_shape = (n_rows, ) if is_aligned else (n_rows, n_cols)
        return bboxes1.new(batch_shape + out_shape)

    def _area(boxes):
        return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])

    area1 = _area(bboxes1)
    area2 = _area(bboxes2)

    # Aligned: compare row i with row i. Pairwise: insert singleton dims so
    # every (row, col) pair broadcasts to [B, rows, cols, ...].
    if is_aligned:
        b1, b2 = bboxes1, bboxes2
        a1, a2 = area1, area2
    else:
        b1 = bboxes1[..., :, None, :]
        b2 = bboxes2[..., None, :, :]
        a1 = area1[..., None]
        a2 = area2[..., None, :]

    inter_lt = torch.max(b1[..., :2], b2[..., :2])
    inter_rb = torch.min(b1[..., 2:], b2[..., 2:])
    inter_wh = fp16_clamp(inter_rb - inter_lt, min=0)
    overlap = inter_wh[..., 0] * inter_wh[..., 1]

    # 'iof' normalises by the area of bboxes1 only.
    union = (a1 + a2 - overlap) if mode in ['iou', 'giou'] else a1
    eps_t = union.new_tensor([eps])
    union = torch.max(union, eps_t)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious

    # GIoU: penalise by the empty fraction of the smallest enclosing box.
    outer_lt = torch.min(b1[..., :2], b2[..., :2])
    outer_rb = torch.max(b1[..., 2:], b2[..., 2:])
    outer_wh = fp16_clamp(outer_rb - outer_lt, min=0)
    outer_area = torch.max(outer_wh[..., 0] * outer_wh[..., 1], eps_t)
    return ious - (outer_area - union) / outer_area
def cast_tensor_type(x, scale=1., dtype=None):
    """Return ``x / scale`` as float16 when ``dtype == 'fp16'``, else ``x`` unchanged.

    Dividing by ``scale`` before the cast helps keep large values inside the
    float16 range (prevents overflow).
    """
    if dtype != 'fp16':
        return x
    return (x / scale).half()
def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False, scale=1., dtype=None):
    """2D Overlaps (e.g. IoUs, GIoUs) Calculator.

    Args:
        bboxes1 (Tensor): bboxes of shape (m, 4) in <x1, y1, x2, y2> format,
            or shape (m, 5) in <x1, y1, x2, y2, score> format (score ignored).
        bboxes2 (Tensor): bboxes of shape (n, 4) in <x1, y1, x2, y2> format,
            shape (n, 5) in <x1, y1, x2, y2, score> format, or empty. If
            ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection
            over foreground), or "giou" (generalized intersection over union).
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        scale (float): divisor applied to both inputs before the fp16 cast.
        dtype (str | None): 'fp16' to compute in half precision (saves memory);
            CPU half results are converted back to float32.

    Returns:
        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
    """
    assert bboxes1.size(-1) in [0, 4, 5]
    assert bboxes2.size(-1) in [0, 4, 5]
    # Drop a trailing score column when present.
    bboxes1, bboxes2 = (
        boxes[..., :4] if boxes.size(-1) == 5 else boxes
        for boxes in (bboxes1, bboxes2)
    )

    if dtype != 'fp16':
        return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)

    # Half precision saves cpu and cuda memory while keeping the speed.
    overlaps = bbox_overlaps(
        cast_tensor_type(bboxes1, scale, dtype),
        cast_tensor_type(bboxes2, scale, dtype),
        mode, is_aligned)
    if overlaps.dtype == torch.float16 and not overlaps.is_cuda:
        # resume cpu float32
        overlaps = overlaps.float()
    return overlaps
def dist_calculator(gt_bboxes, anchor_bboxes):
    """Euclidean distance between every GT box centre and every anchor centre.

    Args:
        gt_bboxes (Tensor): shape(bs*n_max_boxes, 4) as (x1, y1, x2, y2)
        anchor_bboxes (Tensor): shape(num_total_anchors, 4) as (x1, y1, x2, y2)
    Return:
        distances (Tensor): shape(bs*n_max_boxes, num_total_anchors)
        ac_points (Tensor): shape(num_total_anchors, 2), anchor centres (x, y)
    """
    def _centers(boxes):
        cx = (boxes[:, 0] + boxes[:, 2]) / 2.0
        cy = (boxes[:, 1] + boxes[:, 3]) / 2.0
        return torch.stack([cx, cy], dim=1)

    gt_points = _centers(gt_bboxes)
    ac_points = _centers(anchor_bboxes)
    # (num_gt, 1, 2) - (1, num_anchors, 2) -> (num_gt, num_anchors, 2)
    offsets = gt_points.unsqueeze(1) - ac_points.unsqueeze(0)
    distances = offsets.pow(2).sum(-1).sqrt()
    return distances, ac_points
def iou_calculator(box1, box2, eps=1e-9):
    """Pairwise IoU between two batched sets of boxes.

    Args:
        box1 (Tensor): shape(bs, n_max_boxes, 4) as (x1, y1, x2, y2)
        box2 (Tensor): shape(bs, num_total_anchors, 4) as (x1, y1, x2, y2)
        eps (float): added to the union to avoid division by zero
    Return:
        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    a = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
    b = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
    a_tl, a_br = a[:, :, :, 0:2], a[:, :, :, 2:4]
    b_tl, b_br = b[:, :, :, 0:2], b[:, :, :, 2:4]
    inter = (torch.minimum(a_br, b_br) - torch.maximum(a_tl, b_tl)).clip(0).prod(-1)
    area_a = (a_br - a_tl).clip(0).prod(-1)
    area_b = (b_br - b_tl).clip(0).prod(-1)
    return inter / (area_a + area_b - inter + eps)
class ATSSAssigner(nn.Module):
    '''Adaptive Training Sample Selection Assigner'''
    def __init__(self,
                 topk=9,
                 num_classes=80):
        super(ATSSAssigner, self).__init__()
        self.topk = topk
        self.num_classes = num_classes
        # Background label id: one past the last foreground class.
        self.bg_idx = num_classes

    @torch.no_grad()
    def forward(self,
                anc_bboxes,
                n_level_bboxes,
                gt_labels,
                gt_bboxes,
                mask_gt,
                pd_bboxes):
        r"""Assign ground-truth targets to anchors with ATSS.

        This code is based on
        https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py

        Args:
            anc_bboxes (Tensor): shape(num_total_anchors, 4)
            n_level_bboxes (List[int]): anchors per pyramid level, in the order
                they appear in ``anc_bboxes``; must sum to num_total_anchors
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1); positives are only
                kept where it is non-zero (presumably 0 marks padded GTs)
            pd_bboxes (Tensor | None): shape(bs, num_total_anchors, 4); when
                given, target scores are scaled by the IoU between each
                anchor's prediction and its assigned GT
        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors), background = num_classes
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors), bool
            target_gt_idx (Tensor): shape(bs, num_total_anchors), long
        """
        self.n_anchors = anc_bboxes.size(0)
        self.bs = gt_bboxes.size(0)
        self.n_max_boxes = gt_bboxes.size(1)

        if self.n_max_boxes == 0:
            # No GT in the batch: every anchor is background. dtypes mirror
            # the regular path (bool fg_mask, long target_gt_idx).
            device = gt_bboxes.device
            return (
                torch.full([self.bs, self.n_anchors], self.bg_idx,
                           dtype=gt_labels.dtype, device=device),
                torch.zeros([self.bs, self.n_anchors, 4], dtype=gt_bboxes.dtype, device=device),
                torch.zeros([self.bs, self.n_anchors, self.num_classes], device=device),
                torch.zeros([self.bs, self.n_anchors], dtype=torch.bool, device=device),
                torch.zeros([self.bs, self.n_anchors], dtype=torch.long, device=device),
            )

        overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        overlaps = overlaps.reshape([self.bs, -1, self.n_anchors])

        distances, ac_points = dist_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        distances = distances.reshape([self.bs, -1, self.n_anchors])

        is_in_candidate, candidate_idxs = self.select_topk_candidates(
            distances, n_level_bboxes, mask_gt)

        overlaps_thr_per_gt, iou_candidates = self.thres_calculator(
            is_in_candidate, candidate_idxs, overlaps)

        # select candidates with iou > threshold as positive; the (bs, n, 1)
        # threshold broadcasts over the anchor dimension
        is_pos = torch.where(
            iou_candidates > overlaps_thr_per_gt,
            is_in_candidate, torch.zeros_like(is_in_candidate))

        # A positive must also have its anchor centre inside the GT box.
        is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes)
        mask_pos = is_pos * is_in_gts * mask_gt

        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
            mask_pos, overlaps, self.n_max_boxes)

        # assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(
            gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # soft label with iou
        if pd_bboxes is not None:
            ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos
            ious = ious.max(axis=-2)[0].unsqueeze(-1)
            target_scores *= ious

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def select_topk_candidates(self,
                               distances,
                               n_level_bboxes,
                               mask_gt):
        """Pick, per GT and per level, the anchors with the smallest centre distance.

        Args:
            distances (Tensor): shape(bs, n_max_boxes, num_total_anchors)
            n_level_bboxes (List[int]): anchors per level (split sizes of the last dim)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
        Returns:
            is_in_candidate (Tensor): shape(bs, n_max_boxes, num_total_anchors),
                1 for candidate anchors, in ``distances.dtype``
            candidate_idxs (Tensor): shape(bs, n_max_boxes, sum_k), global anchor
                indices of the candidates, where each level contributes
                ``min(topk, level_size)`` entries
        """
        # Kept as (bs, n_max_boxes, 1) so it broadcasts against each level's
        # k = min(topk, level_size) columns; repeating it to topk columns breaks
        # on levels with fewer than topk anchors.
        mask_gt = mask_gt.bool()
        level_distances = torch.split(distances, n_level_bboxes, dim=-1)
        is_in_candidate_list = []
        candidate_idxs = []
        start_idx = 0
        for per_level_distances, per_level_boxes in zip(level_distances, n_level_bboxes):
            end_idx = start_idx + per_level_boxes
            selected_k = min(self.topk, per_level_boxes)
            _, per_level_topk_idxs = per_level_distances.topk(selected_k, dim=-1, largest=False)
            candidate_idxs.append(per_level_topk_idxs + start_idx)
            # GTs masked out by mask_gt point all k picks at anchor 0; for k > 1
            # that count exceeds 1 and is cleared by the filter below (for k == 1
            # forward() removes it when multiplying by mask_gt).
            per_level_topk_idxs = torch.where(
                mask_gt, per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs))
            is_in_candidate = F.one_hot(per_level_topk_idxs, per_level_boxes).sum(dim=-2)
            is_in_candidate = torch.where(
                is_in_candidate > 1, torch.zeros_like(is_in_candidate), is_in_candidate)
            is_in_candidate_list.append(is_in_candidate.to(distances.dtype))
            start_idx = end_idx

        is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
        candidate_idxs = torch.cat(candidate_idxs, dim=-1)
        return is_in_candidate_list, candidate_idxs

    def thres_calculator(self,
                         is_in_candidate,
                         candidate_idxs,
                         overlaps):
        """Compute each GT's adaptive IoU threshold: mean + std of its candidates' IoU.

        Args:
            is_in_candidate (Tensor): shape(bs, n_max_boxes, num_total_anchors)
            candidate_idxs (Tensor): shape(bs, n_max_boxes, sum_k), global anchor indices
            overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors), GT-anchor IoU
        Returns:
            overlaps_thr_per_gt (Tensor): shape(bs, n_max_boxes, 1)
            candidate overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors),
                IoU at candidate positions and 0 elsewhere
        """
        n_bs_max_boxes = self.bs * self.n_max_boxes
        _candidate_overlaps = torch.where(is_in_candidate > 0, overlaps, torch.zeros_like(overlaps))
        candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1])
        # Offset each (batch, gt) row so its anchor indices address the
        # flattened (bs * n_max_boxes * num_total_anchors) overlap tensor.
        assist_idxs = self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device)
        flat_idxs = candidate_idxs + assist_idxs[:, None]
        candidate_overlaps = _candidate_overlaps.reshape(-1)[flat_idxs]
        candidate_overlaps = candidate_overlaps.reshape([self.bs, self.n_max_boxes, -1])

        overlaps_mean_per_gt = candidate_overlaps.mean(axis=-1, keepdim=True)
        # NOTE(review): unbiased std is NaN when only one candidate exists.
        overlaps_std_per_gt = candidate_overlaps.std(axis=-1, keepdim=True)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
        return overlaps_thr_per_gt, _candidate_overlaps

    def get_targets(self,
                    gt_labels,
                    gt_bboxes,
                    target_gt_idx,
                    fg_mask):
        """Gather labels, boxes and one-hot scores of the GT assigned to each anchor.

        Args:
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            target_gt_idx (Tensor): shape(bs, num_total_anchors), GT index per anchor
            fg_mask (Tensor): shape(bs, num_total_anchors), > 0 for foreground
        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors), ``bg_idx`` for background
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes), float
        """
        # Integer batch offsets: computing them in gt_labels' (possibly float /
        # half) dtype can lose precision for large flat indices.
        batch_idx = torch.arange(self.bs, dtype=torch.long, device=gt_labels.device)
        batch_idx = batch_idx[..., None]
        flat_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long().flatten()

        # assigned target labels
        target_labels = gt_labels.flatten()[flat_gt_idx]
        target_labels = target_labels.reshape([self.bs, self.n_anchors])
        target_labels = torch.where(
            fg_mask > 0, target_labels, torch.full_like(target_labels, self.bg_idx))

        # assigned target boxes
        target_bboxes = gt_bboxes.reshape([-1, 4])[flat_gt_idx]
        target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4])

        # assigned target scores: one-hot over num_classes + 1, background column dropped
        target_scores = F.one_hot(target_labels.long(), self.num_classes + 1).float()
        target_scores = target_scores[:, :, :self.num_classes]

        return target_labels, target_bboxes, target_scores
|