# atss.py

import torch
import torch.nn as nn
import torch.nn.functional as F


def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    """
    Select the anchor centers that lie inside the gt boxes as positive candidates.
    Args:
        xy_centers (Tensor): shape(h*w, 2)
        gt_bboxes (Tensor): shape(b, n_boxes, 4)
    Returns:
        (Tensor): shape(b, n_boxes, h*w)
    """
    n_anchors = xy_centers.shape[0]
    bs, n_boxes, _ = gt_bboxes.shape
    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
    return bbox_deltas.amin(3).gt_(eps)
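

# A hedged usage sketch (toy values, not part of the original file): one 10x10
# gt box and two candidate centers; only the center inside the box is kept.
#
#   centers = torch.tensor([[5., 5.], [50., 50.]])   # (h*w=2, 2)
#   gts = torch.tensor([[[0., 0., 10., 10.]]])       # (b=1, n_boxes=1, 4)
#   select_candidates_in_gts(centers, gts)
#   # -> [[[1., 0.]]]: deltas (5,5,5,5) vs (50,50,-40,-40), so only center 0 lies inside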


def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    """
    If an anchor box is assigned to multiple gts, the one with the highest IoU is selected.
    Args:
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
        overlaps (Tensor): shape(b, n_max_boxes, h*w)
    Returns:
        target_gt_idx (Tensor): shape(b, h*w)
        fg_mask (Tensor): shape(b, h*w)
        mask_pos (Tensor): shape(b, n_max_boxes, h*w)
    """
    # (b, n_max_boxes, h*w) -> (b, h*w)
    fg_mask = mask_pos.sum(-2)
    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
        fg_mask = mask_pos.sum(-2)
    # find which gt (index) each grid cell serves
    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
    return target_gt_idx, fg_mask, mask_pos
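

# Hedged sketch (toy values, not from the original file): two gts both claim
# anchor 0; after resolution it goes to the gt with the higher overlap.
#
#   mask_pos = torch.tensor([[[1., 1.], [1., 0.]]])   # (b=1, n_max_boxes=2, h*w=2)
#   overlaps = torch.tensor([[[0.3, 0.6], [0.8, 0.1]]])
#   idx, fg, mp = select_highest_overlaps(mask_pos, overlaps, n_max_boxes=2)
#   # anchor 0 is contested (0.3 vs 0.8), so it goes to gt 1 -> idx = [[1, 0]], fg = [[1., 1.]]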


def generate_anchors(feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5, device='cpu', is_eval=False, mode='af'):
    '''Generate anchors from features.'''
    anchors = []
    anchor_points = []
    stride_tensor = []
    num_anchors_list = []
    assert feats is not None
    if is_eval:
        for i, stride in enumerate(fpn_strides):
            _, _, h, w = feats[i].shape
            shift_x = torch.arange(end=w, device=device) + grid_cell_offset
            shift_y = torch.arange(end=h, device=device) + grid_cell_offset
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            anchor_point = torch.stack([shift_x, shift_y], axis=-1).to(torch.float)
            if mode == 'af':  # anchor-free
                anchor_points.append(anchor_point.reshape([-1, 2]))
                stride_tensor.append(
                    torch.full((h * w, 1), stride, dtype=torch.float, device=device))
            elif mode == 'ab':  # anchor-based
                anchor_points.append(anchor_point.reshape([-1, 2]).repeat(3, 1))
                stride_tensor.append(
                    torch.full((h * w, 1), stride, dtype=torch.float, device=device).repeat(3, 1))
        anchor_points = torch.cat(anchor_points)
        stride_tensor = torch.cat(stride_tensor)
        return anchor_points, stride_tensor
    else:
        for i, stride in enumerate(fpn_strides):
            _, _, h, w = feats[i].shape
            cell_half_size = grid_cell_size * stride * 0.5
            shift_x = (torch.arange(end=w, device=device) + grid_cell_offset) * stride
            shift_y = (torch.arange(end=h, device=device) + grid_cell_offset) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x, indexing='ij')
            anchor = torch.stack(
                [
                    shift_x - cell_half_size, shift_y - cell_half_size,
                    shift_x + cell_half_size, shift_y + cell_half_size
                ],
                axis=-1).clone().to(feats[0].dtype)
            anchor_point = torch.stack([shift_x, shift_y], axis=-1).clone().to(feats[0].dtype)
            if mode == 'af':  # anchor-free
                anchors.append(anchor.reshape([-1, 4]))
                anchor_points.append(anchor_point.reshape([-1, 2]))
            elif mode == 'ab':  # anchor-based
                anchors.append(anchor.reshape([-1, 4]).repeat(3, 1))
                anchor_points.append(anchor_point.reshape([-1, 2]).repeat(3, 1))
            num_anchors_list.append(len(anchors[-1]))
            stride_tensor.append(
                torch.full([num_anchors_list[-1], 1], stride, dtype=feats[0].dtype))
        anchors = torch.cat(anchors)
        anchor_points = torch.cat(anchor_points).to(device)
        stride_tensor = torch.cat(stride_tensor).to(device)
        return anchors, anchor_points, num_anchors_list, stride_tensor
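

# Hedged sketch (assumed inputs, not from the original file): dummy FPN features
# at strides 8/16/32 for a 640x640 input give 80*80 + 40*40 + 20*20 = 8400 anchors.
#
#   feats = [torch.zeros(1, 256, 640 // s, 640 // s) for s in (8, 16, 32)]
#   anchors, points, n_list, strides = generate_anchors(feats, fpn_strides=[8, 16, 32])
#   # anchors: (8400, 4), points: (8400, 2), n_list == [6400, 1600, 400], strides: (8400, 1)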


def fp16_clamp(x, min=None, max=None):
    if not x.is_cuda and x.dtype == torch.float16:
        # clamp for cpu float16: fp16 tensors on cpu have no clamp implementation
        return x.float().clamp(min, max).half()
    return x.clamp(min, max)


def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 contributed by https://github.com/open-mmlab/mmdetection/pull/4889

    Note:
        Assume bboxes1 is M x 4 and bboxes2 is N x 4. When mode is 'iou',
        the following intermediate variables are created:

        1) is_aligned is False
            area1: M x 1
            area2: N x 1
            lt: M x N x 2
            rb: M x N x 2
            wh: M x N x 2
            overlap: M x N x 1
            union: M x N x 1
            ious: M x N x 1

            Total memory:
                S = (9 x N x M + N + M) * 4 Byte
            When using FP16, we can reduce:
                R = (9 x N x M + N + M) * 4 / 2 Byte
            R > (N + M) * 4 * 2 always holds when N, M >= 1, since
            N + M <= N * M < 3 * N * M when N >= 2 and M >= 2, and
            N + 1 < 3 * N when N or M is 1.

            Given M = 40 (ground truths) and N = 400000 (three anchor boxes
            per grid, FPN, R-CNNs), R = 275 MB per forward pass.
            In a dense-detection case with M = 512 ground truths,
            R = 3516 MB = 3.43 GB.
            With batch size B the saving becomes B x R, so without it CUDA
            memory frequently runs out.

            Experiments on GeForce RTX 2080Ti (11019 MiB):
            | dtype |  M  |   N    |   Use    |   Real   |  Ideal   |
            |:-----:|:---:|:------:|:--------:|:--------:|:--------:|
            | FP32  | 512 | 400000 | 8020 MiB |    --    |    --    |
            | FP16  | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
            | FP32  | 40  | 400000 | 1540 MiB |    --    |    --    |
            | FP16  | 40  | 400000 | 1264 MiB | 276 MiB  | 275 MiB  |

        2) is_aligned is True
            area1: N x 1
            area2: N x 1
            lt: N x 2
            rb: N x 2
            wh: N x 2
            overlap: N x 1
            union: N x 1
            ious: N x 1

            Total memory:
                S = 11 x N * 4 Byte
            When using FP16, we can reduce:
                R = 11 x N * 4 / 2 Byte

        The same holds for 'giou' (which uses more memory than 'iou').
        Time-wise, FP16 is generally faster than FP32. When gpu_assign_thr
        is not -1, it takes more time on cpu but does not reduce memory.
        Therefore, we can halve the memory while keeping the speed.

    If ``is_aligned`` is ``False``, calculate the overlaps between each bbox
    of bboxes1 and bboxes2; otherwise calculate the overlaps between each
    aligned pair of bboxes1 and bboxes2.

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn).
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): A value added to the denominator for numerical
            stability. Default 1e-6.

    Returns:
        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dim must be the same
    # Batch dim: (B1, B2, ... Bn)
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        if is_aligned:
            return bboxes1.new(batch_shape + (rows, ))
        else:
            return bboxes1.new(batch_shape + (rows, cols))

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    if is_aligned:
        lt = torch.max(bboxes1[..., :2], bboxes2[..., :2])  # [B, rows, 2]
        rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:])  # [B, rows, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1 + area2 - overlap
        else:
            union = area1
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
            enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
    else:
        lt = torch.max(bboxes1[..., :, None, :2],
                       bboxes2[..., None, :, :2])  # [B, rows, cols, 2]
        rb = torch.min(bboxes1[..., :, None, 2:],
                       bboxes2[..., None, :, 2:])  # [B, rows, cols, 2]

        wh = fp16_clamp(rb - lt, min=0)
        overlap = wh[..., 0] * wh[..., 1]

        if mode in ['iou', 'giou']:
            union = area1[..., None] + area2[..., None, :] - overlap
        else:
            union = area1[..., None]
        if mode == 'giou':
            enclosed_lt = torch.min(bboxes1[..., :, None, :2],
                                    bboxes2[..., None, :, :2])
            enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
                                    bboxes2[..., None, :, 2:])

    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious
    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious


def cast_tensor_type(x, scale=1., dtype=None):
    if dtype == 'fp16':
        # scale is for preventing overflows
        x = (x / scale).half()
    return x


def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False, scale=1., dtype=None):
    """2D overlaps (e.g. IoU, GIoU) calculator: calculate IoU between 2D bboxes.

    Args:
        bboxes1 (Tensor): bboxes have shape (m, 4) in <x1, y1, x2, y2>
            format, or shape (m, 5) in <x1, y1, x2, y2, score> format.
        bboxes2 (Tensor): bboxes have shape (n, 4) in <x1, y1, x2, y2>
            format, shape (n, 5) in <x1, y1, x2, y2, score> format, or be
            empty. If ``is_aligned`` is ``True``, then m and n must be
            equal.
        mode (str): "iou" (intersection over union), "iof" (intersection
            over foreground), or "giou" (generalized intersection over
            union).
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
    Returns:
        Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
    """
    assert bboxes1.size(-1) in [0, 4, 5]
    assert bboxes2.size(-1) in [0, 4, 5]
    if bboxes2.size(-1) == 5:
        bboxes2 = bboxes2[..., :4]
    if bboxes1.size(-1) == 5:
        bboxes1 = bboxes1[..., :4]

    if dtype == 'fp16':
        # change tensor type to save cpu and cuda memory and keep speed
        bboxes1 = cast_tensor_type(bboxes1, scale, dtype)
        bboxes2 = cast_tensor_type(bboxes2, scale, dtype)
        overlaps = bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
        if not overlaps.is_cuda and overlaps.dtype == torch.float16:
            # resume cpu float32
            overlaps = overlaps.float()
        return overlaps
    return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
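

# Minimal sketch of the fp16 path (toy boxes, assumed values): dtype='fp16'
# computes overlaps in half precision and restores fp32 on cpu.
#
#   b1 = torch.tensor([[0., 0., 10., 10.]])
#   b2 = torch.tensor([[5., 5., 15., 15.]])
#   iou2d_calculator(b1, b2, dtype='fp16')   # ~0.1429 (25 / 175), returned as fp32 on cpu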


def dist_calculator(gt_bboxes, anchor_bboxes):
    """Compute center distances between all anchor bboxes and gt bboxes.

    Args:
        gt_bboxes (Tensor): shape(bs*n_max_boxes, 4)
        anchor_bboxes (Tensor): shape(num_total_anchors, 4)
    Returns:
        distances (Tensor): shape(bs*n_max_boxes, num_total_anchors)
        ac_points (Tensor): shape(num_total_anchors, 2)
    """
    gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
    gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
    gt_points = torch.stack([gt_cx, gt_cy], dim=1)
    ac_cx = (anchor_bboxes[:, 0] + anchor_bboxes[:, 2]) / 2.0
    ac_cy = (anchor_bboxes[:, 1] + anchor_bboxes[:, 3]) / 2.0
    ac_points = torch.stack([ac_cx, ac_cy], dim=1)
    distances = (gt_points[:, None, :] - ac_points[None, :, :]).pow(2).sum(-1).sqrt()
    return distances, ac_points
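

# Hedged sketch (toy tensors, not from the original file): center distance
# between one gt and two anchors.
#
#   gts = torch.tensor([[0., 0., 10., 10.]])                       # center (5, 5)
#   ancs = torch.tensor([[4., 4., 6., 6.], [20., 20., 30., 30.]])  # centers (5, 5), (25, 25)
#   d, pts = dist_calculator(gts, ancs)
#   # d = [[0., 28.2843]]  (sqrt(20^2 + 20^2)), pts = [[5., 5.], [25., 25.]]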


def iou_calculator(box1, box2, eps=1e-9):
    """Calculate IoU for a batch of boxes.

    Args:
        box1 (Tensor): shape(bs, n_max_boxes, 4)
        box2 (Tensor): shape(bs, num_total_anchors, 4)
    Returns:
        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
    x1y1 = torch.maximum(px1y1, gx1y1)
    x2y2 = torch.minimum(px2y2, gx2y2)
    overlap = (x2y2 - x1y1).clip(0).prod(-1)
    area1 = (px2y2 - px1y1).clip(0).prod(-1)
    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
    union = area1 + area2 - overlap + eps
    return overlap / union
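

# Hedged sketch (toy batch, assumed values): batched IoU between one gt and
# two predicted boxes.
#
#   gt = torch.tensor([[[0., 0., 10., 10.]]])                    # (bs=1, n_max_boxes=1, 4)
#   pd = torch.tensor([[[0., 0., 10., 10.], [0., 0., 5., 5.]]])  # (bs=1, 2, 4)
#   iou_calculator(gt, pd)   # tensor([[[1.0000, 0.2500]]])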


class ATSSAssigner(nn.Module):
    '''Adaptive Training Sample Selection Assigner'''

    def __init__(self,
                 topk=9,
                 num_classes=80):
        super(ATSSAssigner, self).__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.bg_idx = num_classes

    @torch.no_grad()
    def forward(self,
                anc_bboxes,
                n_level_bboxes,
                gt_labels,
                gt_bboxes,
                mask_gt,
                pd_bboxes):
        r"""This code is based on
        https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py

        Args:
            anc_bboxes (Tensor): shape(num_total_anchors, 4)
            n_level_bboxes (List): number of anchors per FPN level, len(num_levels)
            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
        Returns:
            target_labels (Tensor): shape(bs, num_total_anchors)
            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
            fg_mask (Tensor): shape(bs, num_total_anchors)
            target_gt_idx (Tensor): shape(bs, num_total_anchors)
        """
        self.n_anchors = anc_bboxes.size(0)
        self.bs = gt_bboxes.size(0)
        self.n_max_boxes = gt_bboxes.size(1)

        if self.n_max_boxes == 0:
            device = gt_bboxes.device
            return torch.full([self.bs, self.n_anchors], self.bg_idx).to(device), \
                   torch.zeros([self.bs, self.n_anchors, 4]).to(device), \
                   torch.zeros([self.bs, self.n_anchors, self.num_classes]).to(device), \
                   torch.zeros([self.bs, self.n_anchors]).to(device), \
                   torch.zeros([self.bs, self.n_anchors]).to(device)

        overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        overlaps = overlaps.reshape([self.bs, -1, self.n_anchors])
        distances, ac_points = dist_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
        distances = distances.reshape([self.bs, -1, self.n_anchors])

        is_in_candidate, candidate_idxs = self.select_topk_candidates(
            distances, n_level_bboxes, mask_gt)
        overlaps_thr_per_gt, iou_candidates = self.thres_calculator(
            is_in_candidate, candidate_idxs, overlaps)

        # select candidates with iou >= threshold as positives
        is_pos = torch.where(
            iou_candidates > overlaps_thr_per_gt.repeat([1, 1, self.n_anchors]),
            is_in_candidate, torch.zeros_like(is_in_candidate))

        is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes)
        mask_pos = is_pos * is_in_gts * mask_gt
        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
            mask_pos, overlaps, self.n_max_boxes)

        # assigned targets
        target_labels, target_bboxes, target_scores = self.get_targets(
            gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # soft label with iou
        if pd_bboxes is not None:
            ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos
            ious = ious.max(axis=-2)[0].unsqueeze(-1)
            target_scores *= ious

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def select_topk_candidates(self,
                               distances,
                               n_level_bboxes,
                               mask_gt):
        mask_gt = mask_gt.repeat(1, 1, self.topk).bool()
        level_distances = torch.split(distances, n_level_bboxes, dim=-1)
        is_in_candidate_list = []
        candidate_idxs = []
        start_idx = 0
        for per_level_distances, per_level_boxes in zip(level_distances, n_level_bboxes):
            end_idx = start_idx + per_level_boxes
            selected_k = min(self.topk, per_level_boxes)
            _, per_level_topk_idxs = per_level_distances.topk(selected_k, dim=-1, largest=False)
            candidate_idxs.append(per_level_topk_idxs + start_idx)
            per_level_topk_idxs = torch.where(mask_gt,
                per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs))
            is_in_candidate = F.one_hot(per_level_topk_idxs, per_level_boxes).sum(dim=-2)
            is_in_candidate = torch.where(is_in_candidate > 1,
                torch.zeros_like(is_in_candidate), is_in_candidate)
            is_in_candidate_list.append(is_in_candidate.to(distances.dtype))
            start_idx = end_idx

        is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
        candidate_idxs = torch.cat(candidate_idxs, dim=-1)

        return is_in_candidate_list, candidate_idxs

    def thres_calculator(self,
                         is_in_candidate,
                         candidate_idxs,
                         overlaps):
        n_bs_max_boxes = self.bs * self.n_max_boxes
        _candidate_overlaps = torch.where(is_in_candidate > 0, overlaps, torch.zeros_like(overlaps))
        candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1])
        assist_idxs = self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device)
        assist_idxs = assist_idxs[:, None]
        flatten_idxs = candidate_idxs + assist_idxs
        candidate_overlaps = _candidate_overlaps.reshape(-1)[flatten_idxs]
        candidate_overlaps = candidate_overlaps.reshape([self.bs, self.n_max_boxes, -1])

        # ATSS threshold: per-gt mean + std of candidate overlaps
        overlaps_mean_per_gt = candidate_overlaps.mean(axis=-1, keepdim=True)
        overlaps_std_per_gt = candidate_overlaps.std(axis=-1, keepdim=True)
        overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt

        return overlaps_thr_per_gt, _candidate_overlaps

    def get_targets(self,
                    gt_labels,
                    gt_bboxes,
                    target_gt_idx,
                    fg_mask):
        # assigned target labels
        batch_idx = torch.arange(self.bs, dtype=gt_labels.dtype, device=gt_labels.device)
        batch_idx = batch_idx[..., None]
        target_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long()
        target_labels = gt_labels.flatten()[target_gt_idx.flatten()]
        target_labels = target_labels.reshape([self.bs, self.n_anchors])
        target_labels = torch.where(fg_mask > 0,
            target_labels, torch.full_like(target_labels, self.bg_idx))

        # assigned target boxes
        target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()]
        target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4])

        # assigned target scores
        target_scores = F.one_hot(target_labels.long(), self.num_classes + 1).float()
        target_scores = target_scores[:, :, :self.num_classes]

        return target_labels, target_bboxes, target_scores
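

# Hedged end-to-end smoke test (all shapes and values below are assumptions for
# illustration, not part of the original file): wire generate_anchors into the
# assigner for a single 640x640 image with one ground-truth box.
if __name__ == '__main__':
    feats = [torch.zeros(1, 256, 640 // s, 640 // s) for s in (8, 16, 32)]
    anchors, points, n_level_bboxes, strides = generate_anchors(feats, [8, 16, 32])
    gt_labels = torch.zeros(1, 1, 1)                        # one gt of class 0
    gt_bboxes = torch.tensor([[[100., 100., 200., 200.]]])  # (bs, n_max_boxes, 4)
    mask_gt = torch.ones(1, 1, 1)
    pd_bboxes = torch.rand(1, anchors.shape[0], 4) * 640    # dummy predictions
    assigner = ATSSAssigner(topk=9, num_classes=80)
    labels, bboxes, scores, fg_mask, gt_idx = assigner(
        anchors, n_level_bboxes, gt_labels, gt_bboxes, mask_gt, pd_bboxes)
    print(labels.shape, bboxes.shape, scores.shape, fg_mask.shape)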