loss.py 59 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from ultralytics.utils.metrics import OKS_SIGMA
  6. from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
  7. from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
  8. from ultralytics.utils.atss import ATSSAssigner, generate_anchors
  9. from .metrics import bbox_iou, probiou, bbox_mpdiou, bbox_inner_iou, bbox_focaler_iou, bbox_inner_mpdiou, bbox_focaler_mpdiou, wasserstein_loss, WiseIouLoss
  10. from .tal import bbox2dist
  11. import math
  12. class SlideLoss(nn.Module):
  13. def __init__(self, loss_fcn):
  14. super(SlideLoss, self).__init__()
  15. self.loss_fcn = loss_fcn
  16. self.reduction = loss_fcn.reduction
  17. self.loss_fcn.reduction = 'none' # required to apply SL to each element
  18. def forward(self, pred, true, auto_iou=0.5):
  19. loss = self.loss_fcn(pred, true)
  20. if auto_iou < 0.2:
  21. auto_iou = 0.2
  22. b1 = true <= auto_iou - 0.1
  23. a1 = 1.0
  24. b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)
  25. a2 = math.exp(1.0 - auto_iou)
  26. b3 = true >= auto_iou
  27. a3 = torch.exp(-(true - 1.0))
  28. modulating_weight = a1 * b1 + a2 * b2 + a3 * b3
  29. loss *= modulating_weight
  30. if self.reduction == 'mean':
  31. return loss.mean()
  32. elif self.reduction == 'sum':
  33. return loss.sum()
  34. else: # 'none'
  35. return loss
  36. class EMASlideLoss:
  37. def __init__(self, loss_fcn, decay=0.999, tau=2000):
  38. super(EMASlideLoss, self).__init__()
  39. self.loss_fcn = loss_fcn
  40. self.reduction = loss_fcn.reduction
  41. self.loss_fcn.reduction = 'none' # required to apply SL to each element
  42. self.decay = lambda x: decay * (1 - math.exp(-x / tau))
  43. self.is_train = True
  44. self.updates = 0
  45. self.iou_mean = 1.0
  46. def __call__(self, pred, true, auto_iou=0.5):
  47. if self.is_train and auto_iou != -1:
  48. self.updates += 1
  49. d = self.decay(self.updates)
  50. self.iou_mean = d * self.iou_mean + (1 - d) * float(auto_iou.detach())
  51. auto_iou = self.iou_mean
  52. loss = self.loss_fcn(pred, true)
  53. if auto_iou < 0.2:
  54. auto_iou = 0.2
  55. b1 = true <= auto_iou - 0.1
  56. a1 = 1.0
  57. b2 = (true > (auto_iou - 0.1)) & (true < auto_iou)
  58. a2 = math.exp(1.0 - auto_iou)
  59. b3 = true >= auto_iou
  60. a3 = torch.exp(-(true - 1.0))
  61. modulating_weight = a1 * b1 + a2 * b2 + a3 * b3
  62. loss *= modulating_weight
  63. if self.reduction == 'mean':
  64. return loss.mean()
  65. elif self.reduction == 'sum':
  66. return loss.sum()
  67. else: # 'none'
  68. return loss
  69. class VarifocalLoss(nn.Module):
  70. """
  71. Varifocal loss by Zhang et al.
  72. https://arxiv.org/abs/2008.13367.
  73. """
  74. def __init__(self):
  75. """Initialize the VarifocalLoss class."""
  76. super().__init__()
  77. @staticmethod
  78. def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
  79. """Computes varfocal loss."""
  80. weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
  81. with torch.cuda.amp.autocast(enabled=False):
  82. loss = (
  83. (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
  84. .mean(1)
  85. .sum()
  86. )
  87. return loss
  88. class FocalLoss(nn.Module):
  89. """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
  90. def __init__(self):
  91. """Initializer for FocalLoss class with no parameters."""
  92. super().__init__()
  93. @staticmethod
  94. def forward(pred, label, gamma=1.5, alpha=0.25):
  95. """Calculates and updates confusion matrix for object detection/classification tasks."""
  96. loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
  97. # p_t = torch.exp(-loss)
  98. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  99. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  100. pred_prob = pred.sigmoid() # prob from logits
  101. p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
  102. modulating_factor = (1.0 - p_t) ** gamma
  103. loss *= modulating_factor
  104. if alpha > 0:
  105. alpha_factor = label * alpha + (1 - label) * (1 - alpha)
  106. loss *= alpha_factor
  107. return loss.mean(1).sum()
  108. class VarifocalLoss_YOLO(nn.Module):
  109. """
  110. Varifocal loss by Zhang et al.
  111. https://arxiv.org/abs/2008.13367.
  112. """
  113. def __init__(self, alpha=0.75, gamma=2.0):
  114. """Initialize the VarifocalLoss class."""
  115. super().__init__()
  116. self.alpha = alpha
  117. self.gamma = gamma
  118. def forward(self, pred_score, gt_score):
  119. """Computes varfocal loss."""
  120. weight = self.alpha * (pred_score.sigmoid() - gt_score).abs().pow(self.gamma) * (gt_score <= 0.0).float() + gt_score * (gt_score > 0.0).float()
  121. with torch.cuda.amp.autocast(enabled=False):
  122. return F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') * weight
  123. class QualityfocalLoss_YOLO(nn.Module):
  124. def __init__(self, beta=2.0):
  125. super().__init__()
  126. self.beta = beta
  127. def forward(self, pred_score, gt_score, gt_target_pos_mask):
  128. # negatives are supervised by 0 quality score
  129. pred_sigmoid = pred_score.sigmoid()
  130. scale_factor = pred_sigmoid
  131. zerolabel = scale_factor.new_zeros(pred_score.shape)
  132. with torch.cuda.amp.autocast(enabled=False):
  133. loss = F.binary_cross_entropy_with_logits(pred_score, zerolabel, reduction='none') * scale_factor.pow(self.beta)
  134. scale_factor = gt_score[gt_target_pos_mask] - pred_sigmoid[gt_target_pos_mask]
  135. with torch.cuda.amp.autocast(enabled=False):
  136. loss[gt_target_pos_mask] = F.binary_cross_entropy_with_logits(pred_score[gt_target_pos_mask], gt_score[gt_target_pos_mask], reduction='none') * scale_factor.abs().pow(self.beta)
  137. return loss
  138. class FocalLoss_YOLO(nn.Module):
  139. """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
  140. def __init__(self, gamma=1.5, alpha=0.25):
  141. """Initializer for FocalLoss class with no parameters."""
  142. super().__init__()
  143. self.gamma = gamma
  144. self.alpha = alpha
  145. def forward(self, pred, label):
  146. """Calculates and updates confusion matrix for object detection/classification tasks."""
  147. loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
  148. # p_t = torch.exp(-loss)
  149. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  150. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  151. pred_prob = pred.sigmoid() # prob from logits
  152. p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
  153. modulating_factor = (1.0 - p_t) ** self.gamma
  154. loss *= modulating_factor
  155. if self.alpha > 0:
  156. alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
  157. loss *= alpha_factor
  158. return loss
  159. class DFLoss(nn.Module):
  160. """Criterion class for computing DFL losses during training."""
  161. def __init__(self, reg_max=16) -> None:
  162. """Initialize the DFL module."""
  163. super().__init__()
  164. self.reg_max = reg_max
  165. def __call__(self, pred_dist, target):
  166. """
  167. Return sum of left and right DFL losses.
  168. Distribution Focal Loss (DFL) proposed in Generalized Focal Loss
  169. https://ieeexplore.ieee.org/document/9792391
  170. """
  171. target = target.clamp_(0, self.reg_max - 1 - 0.01)
  172. tl = target.long() # target left
  173. tr = tl + 1 # target right
  174. wl = tr - target # weight left
  175. wr = 1 - wl # weight right
  176. return (
  177. F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
  178. + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
  179. ).mean(-1, keepdim=True)
  180. class BboxLoss(nn.Module):
  181. """Criterion class for computing training losses during training."""
  182. def __init__(self, reg_max=16):
  183. """Initialize the BboxLoss module with regularization maximum and DFL settings."""
  184. super().__init__()
  185. self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
  186. # NWD
  187. self.nwd_loss = False
  188. self.iou_ratio = 0.5 # total_iou_loss = self.iou_ratio * iou_loss + (1 - self.iou_ratio) * nwd_loss
  189. # WiseIOU
  190. self.use_wiseiou = False
  191. if self.use_wiseiou:
  192. self.wiou_loss = WiseIouLoss(ltype='WIoU', monotonous=False, inner_iou=False, focaler_iou=False)
  193. def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask, mpdiou_hw=None):
  194. """IoU loss."""
  195. weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
  196. if self.use_wiseiou:
  197. wiou = self.wiou_loss(pred_bboxes[fg_mask], target_bboxes[fg_mask], ret_iou=False, ratio=0.7, d=0.0, u=0.95).unsqueeze(-1)
  198. # wiou = self.wiou_loss(pred_bboxes[fg_mask], target_bboxes[fg_mask], ret_iou=False, ratio=0.7, d=0.0, u=0.95, **{'scale':0.0}).unsqueeze(-1) # Wise-ShapeIoU,Wise-Inner-ShapeIoU,Wise-Focaler-ShapeIoU
  199. # wiou = self.wiou_loss(pred_bboxes[fg_mask], target_bboxes[fg_mask], ret_iou=False, ratio=0.7, d=0.0, u=0.95, **{'mpdiou_hw':mpdiou_hw[fg_mask]}).unsqueeze(-1) # Wise-MPDIoU,Wise-Inner-MPDIoU,Wise-Focaler-MPDIoU
  200. loss_iou = (wiou * weight).sum() / target_scores_sum
  201. else:
  202. iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
  203. # iou = bbox_inner_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True, ratio=0.7)
  204. # iou = bbox_mpdiou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, mpdiou_hw=mpdiou_hw[fg_mask])
  205. # iou = bbox_inner_mpdiou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, mpdiou_hw=mpdiou_hw[fg_mask], ratio=0.7)
  206. # iou = bbox_focaler_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True, d=0.0, u=0.95)
  207. # iou = bbox_focaler_mpdiou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, mpdiou_hw=mpdiou_hw[fg_mask], d=0.0, u=0.95)
  208. loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
  209. if self.nwd_loss:
  210. nwd = wasserstein_loss(pred_bboxes[fg_mask], target_bboxes[fg_mask])
  211. nwd_loss = ((1.0 - nwd) * weight).sum() / target_scores_sum
  212. loss_iou = self.iou_ratio * loss_iou + (1 - self.iou_ratio) * nwd_loss
  213. # DFL loss
  214. if self.dfl_loss:
  215. target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
  216. loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  217. loss_dfl = loss_dfl.sum() / target_scores_sum
  218. else:
  219. loss_dfl = torch.tensor(0.0).to(pred_dist.device)
  220. return loss_iou, loss_dfl
  221. class RotatedBboxLoss(BboxLoss):
  222. """Criterion class for computing training losses during training."""
  223. def __init__(self, reg_max):
  224. """Initialize the BboxLoss module with regularization maximum and DFL settings."""
  225. super().__init__(reg_max)
  226. def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
  227. """IoU loss."""
  228. weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
  229. iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
  230. loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
  231. # DFL loss
  232. if self.dfl_loss:
  233. target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
  234. loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
  235. loss_dfl = loss_dfl.sum() / target_scores_sum
  236. else:
  237. loss_dfl = torch.tensor(0.0).to(pred_dist.device)
  238. return loss_iou, loss_dfl
  239. class KeypointLoss(nn.Module):
  240. """Criterion class for computing training losses."""
  241. def __init__(self, sigmas) -> None:
  242. """Initialize the KeypointLoss class."""
  243. super().__init__()
  244. self.sigmas = sigmas
  245. def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
  246. """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
  247. d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
  248. kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
  249. # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
  250. e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
  251. return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
  252. class v8DetectionLoss:
  253. """Criterion class for computing training losses."""
  254. def __init__(self, model, tal_topk=10): # model must be de-paralleled
  255. """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function."""
  256. device = next(model.parameters()).device # get model device
  257. h = model.args # hyperparameters
  258. m = model.model[-1] # Detect() module
  259. self.bce = nn.BCEWithLogitsLoss(reduction="none")
  260. # self.bce = EMASlideLoss(nn.BCEWithLogitsLoss(reduction='none')) # Exponential Moving Average Slide Loss
  261. # self.bce = SlideLoss(nn.BCEWithLogitsLoss(reduction='none')) # Slide Loss
  262. # self.bce = FocalLoss_YOLO(alpha=0.25, gamma=1.5) # FocalLoss
  263. # self.bce = VarifocalLoss_YOLO(alpha=0.75, gamma=2.0) # VarifocalLoss
  264. # self.bce = QualityfocalLoss_YOLO(beta=2.0) # QualityfocalLoss
  265. self.hyp = h
  266. self.stride = m.stride # model strides
  267. self.nc = m.nc # number of classes
  268. self.no = m.nc + m.reg_max * 4
  269. self.reg_max = m.reg_max
  270. self.device = device
  271. self.use_dfl = m.reg_max > 1
  272. self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
  273. if hasattr(m, 'dfl_aux'):
  274. self.assigner_aux = TaskAlignedAssigner(topk=13, num_classes=self.nc, alpha=0.5, beta=6.0)
  275. self.aux_loss_ratio = 0.25
  276. # self.assigner = ATSSAssigner(9, num_classes=self.nc)
  277. self.bbox_loss = BboxLoss(m.reg_max).to(device)
  278. self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
  279. # ATSS use
  280. self.grid_cell_offset = 0.5
  281. self.fpn_strides = list(self.stride.detach().cpu().numpy())
  282. self.grid_cell_size = 5.0
  283. def preprocess(self, targets, batch_size, scale_tensor):
  284. """Preprocesses the target counts and matches with the input batch size to output a tensor."""
  285. nl, ne = targets.shape
  286. if nl == 0:
  287. out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
  288. else:
  289. i = targets[:, 0] # image index
  290. _, counts = i.unique(return_counts=True)
  291. counts = counts.to(dtype=torch.int32)
  292. out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
  293. for j in range(batch_size):
  294. matches = i == j
  295. n = matches.sum()
  296. if n:
  297. out[j, :n] = targets[matches, 1:]
  298. out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
  299. return out
  300. def bbox_decode(self, anchor_points, pred_dist):
  301. """Decode predicted object bounding box coordinates from anchor points and distribution."""
  302. if self.use_dfl:
  303. b, a, c = pred_dist.shape # batch, anchors, channels
  304. pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
  305. # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
  306. # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
  307. return dist2bbox(pred_dist, anchor_points, xywh=False)
  308. def __call__(self, preds, batch):
  309. if hasattr(self, 'assigner_aux'):
  310. loss, batch_size = self.compute_loss_aux(preds, batch)
  311. else:
  312. loss, batch_size = self.compute_loss(preds, batch)
  313. return loss.sum() * batch_size, loss.detach()
  314. def compute_loss(self, preds, batch):
  315. """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
  316. loss = torch.zeros(3, device=self.device) # box, cls, dfl
  317. feats = preds[1] if isinstance(preds, tuple) else preds
  318. feats = feats[:self.stride.size(0)]
  319. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  320. (self.reg_max * 4, self.nc), 1)
  321. pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  322. pred_distri = pred_distri.permute(0, 2, 1).contiguous()
  323. dtype = pred_scores.dtype
  324. batch_size = pred_scores.shape[0]
  325. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  326. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  327. # targets
  328. targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
  329. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  330. gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  331. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
  332. # pboxes
  333. pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  334. # ATSS
  335. if isinstance(self.assigner, ATSSAssigner):
  336. anchors, _, n_anchors_list, _ = \
  337. generate_anchors(feats, self.fpn_strides, self.grid_cell_size, self.grid_cell_offset, device=feats[0].device)
  338. target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(anchors, n_anchors_list, gt_labels, gt_bboxes, mask_gt, pred_bboxes.detach() * stride_tensor)
  339. # TAL
  340. else:
  341. target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
  342. pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  343. anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
  344. target_scores_sum = max(target_scores.sum(), 1)
  345. # cls loss
  346. if isinstance(self.bce, (nn.BCEWithLogitsLoss, FocalLoss_YOLO)):
  347. loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
  348. elif isinstance(self.bce, VarifocalLoss_YOLO):
  349. if fg_mask.sum():
  350. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  351. # 10.0x Faster than torch.one_hot
  352. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  353. dtype=torch.int64,
  354. device=target_labels.device) # (b, h*w, 80)
  355. cls_iou_targets.scatter_(2, target_labels.unsqueeze(-1), 1)
  356. cls_iou_targets = pos_ious * cls_iou_targets
  357. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  358. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  359. else:
  360. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  361. dtype=torch.int64,
  362. device=target_labels.device) # (b, h*w, 80)
  363. loss[1] = self.bce(pred_scores, cls_iou_targets.to(dtype)).sum() / max(fg_mask.sum(), 1) # BCE
  364. elif isinstance(self.bce, QualityfocalLoss_YOLO):
  365. if fg_mask.sum():
  366. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  367. # 10.0x Faster than torch.one_hot
  368. targets_onehot = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  369. dtype=torch.int64,
  370. device=target_labels.device) # (b, h*w, 80)
  371. targets_onehot.scatter_(2, target_labels.unsqueeze(-1), 1)
  372. cls_iou_targets = pos_ious * targets_onehot
  373. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  374. targets_onehot_pos = torch.where(fg_scores_mask > 0, targets_onehot, 0)
  375. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  376. else:
  377. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  378. dtype=torch.int64,
  379. device=target_labels.device) # (b, h*w, 80)
  380. targets_onehot_pos = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  381. dtype=torch.int64,
  382. device=target_labels.device) # (b, h*w, 80)
  383. loss[1] = self.bce(pred_scores, cls_iou_targets.to(dtype), targets_onehot_pos.to(torch.bool)).sum() / max(fg_mask.sum(), 1)
  384. # bbox loss
  385. if fg_mask.sum():
  386. target_bboxes /= stride_tensor
  387. loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
  388. target_scores_sum, fg_mask, ((imgsz[0] ** 2 + imgsz[1] ** 2) / torch.square(stride_tensor)).repeat(1, batch_size).transpose(1, 0))
  389. if isinstance(self.bce, (EMASlideLoss, SlideLoss)):
  390. if fg_mask.sum():
  391. auto_iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True).mean()
  392. else:
  393. auto_iou = -1
  394. loss[1] = self.bce(pred_scores, target_scores.to(dtype), auto_iou).sum() / target_scores_sum # BCE
  395. loss[0] *= self.hyp.box # box gain
  396. loss[1] *= self.hyp.cls # cls gain
  397. loss[2] *= self.hyp.dfl # dfl gain
  398. return loss, batch_size
  399. def compute_loss_aux(self, preds, batch):
  400. """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
  401. loss = torch.zeros(3, device=self.device) # box, cls, dfl
  402. feats_all = preds[1] if isinstance(preds, tuple) else preds
  403. if len(feats_all) == self.stride.size(0):
  404. return self.compute_loss(preds, batch)
  405. feats, feats_aux = feats_all[:self.stride.size(0)], feats_all[self.stride.size(0):]
  406. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split((self.reg_max * 4, self.nc), 1)
  407. pred_distri_aux, pred_scores_aux = torch.cat([xi.view(feats_aux[0].shape[0], self.no, -1) for xi in feats_aux], 2).split((self.reg_max * 4, self.nc), 1)
  408. pred_scores, pred_distri = pred_scores.permute(0, 2, 1).contiguous(), pred_distri.permute(0, 2, 1).contiguous()
  409. pred_scores_aux, pred_distri_aux = pred_scores_aux.permute(0, 2, 1).contiguous(), pred_distri_aux.permute(0, 2, 1).contiguous()
  410. dtype = pred_scores.dtype
  411. batch_size = pred_scores.shape[0]
  412. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  413. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  414. # targets
  415. targets = torch.cat((batch['batch_idx'].view(-1, 1), batch['cls'].view(-1, 1), batch['bboxes']), 1)
  416. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  417. gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  418. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
  419. # pboxes
  420. pred_bboxes = self.bbox_decode(anchor_points, pred_distri)
  421. pred_bboxes_aux = self.bbox_decode(anchor_points, pred_distri_aux) # xyxy, (b, h*w, 4)
  422. target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  423. anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
  424. target_labels_aux, target_bboxes_aux, target_scores_aux, fg_mask_aux, _ = self.assigner_aux(pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  425. anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt)
  426. target_scores_sum = max(target_scores.sum(), 1)
  427. target_scores_sum_aux = max(target_scores_aux.sum(), 1)
  428. # cls loss
  429. if isinstance(self.bce, nn.BCEWithLogitsLoss):
  430. loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
  431. loss[1] += self.bce(pred_scores_aux, target_scores_aux.to(dtype)).sum() / target_scores_sum_aux * self.aux_loss_ratio # BCE
  432. # bbox loss
  433. if fg_mask.sum():
  434. target_bboxes /= stride_tensor
  435. target_bboxes_aux /= stride_tensor
  436. loss[0], loss[2] = self.bbox_loss(pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores,
  437. target_scores_sum, fg_mask, ((imgsz[0] ** 2 + imgsz[1] ** 2) / torch.square(stride_tensor)).repeat(1, batch_size).transpose(1, 0))
  438. aux_loss_0, aux_loss_2 = self.bbox_loss(pred_distri_aux, pred_bboxes_aux, anchor_points, target_bboxes_aux, target_scores_aux,
  439. target_scores_sum_aux, fg_mask_aux, ((imgsz[0] ** 2 + imgsz[1] ** 2) / torch.square(stride_tensor)).repeat(1, batch_size).transpose(1, 0))
  440. loss[0] += aux_loss_0 * self.aux_loss_ratio
  441. loss[2] += aux_loss_2 * self.aux_loss_ratio
  442. if isinstance(self.bce, (EMASlideLoss, SlideLoss)):
  443. auto_iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True).mean()
  444. loss[1] = self.bce(pred_scores, target_scores.to(dtype), auto_iou).sum() / target_scores_sum # BCE
  445. loss[1] += self.bce(pred_scores_aux, target_scores_aux.to(dtype), -1).sum() / target_scores_sum_aux * self.aux_loss_ratio # BCE
  446. loss[0] *= self.hyp.box # box gain
  447. loss[1] *= self.hyp.cls # cls gain
  448. loss[2] *= self.hyp.dfl # dfl gain
  449. # return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
  450. return loss, batch_size
  451. class v8SegmentationLoss(v8DetectionLoss):
  452. """Criterion class for computing training losses."""
  453. def __init__(self, model): # model must be de-paralleled
  454. """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument."""
  455. super().__init__(model)
  456. self.overlap = model.args.overlap_mask
  457. def __call__(self, preds, batch):
  458. """Calculate and return the loss for the YOLO model."""
  459. loss = torch.zeros(4, device=self.device) # box, cls, dfl
  460. feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
  461. batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
  462. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  463. (self.reg_max * 4, self.nc), 1
  464. )
  465. # B, grids, ..
  466. pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  467. pred_distri = pred_distri.permute(0, 2, 1).contiguous()
  468. pred_masks = pred_masks.permute(0, 2, 1).contiguous()
  469. dtype = pred_scores.dtype
  470. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  471. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  472. # Targets
  473. try:
  474. batch_idx = batch["batch_idx"].view(-1, 1)
  475. targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  476. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  477. gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  478. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  479. except RuntimeError as e:
  480. raise TypeError(
  481. "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
  482. "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
  483. "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
  484. "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
  485. "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
  486. ) from e
  487. # Pboxes
  488. pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  489. # ATSS
  490. if isinstance(self.assigner, ATSSAssigner):
  491. anchors, _, n_anchors_list, _ = \
  492. generate_anchors(feats, self.fpn_strides, self.grid_cell_size, self.grid_cell_offset, device=feats[0].device)
  493. target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(anchors, n_anchors_list, gt_labels, gt_bboxes, mask_gt, pred_bboxes.detach() * stride_tensor)
  494. else:
  495. target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
  496. pred_scores.detach().sigmoid(),
  497. (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  498. anchor_points * stride_tensor,
  499. gt_labels,
  500. gt_bboxes,
  501. mask_gt,
  502. )
  503. target_scores_sum = max(target_scores.sum(), 1)
  504. # Cls loss
  505. if isinstance(self.bce, (nn.BCEWithLogitsLoss, FocalLoss_YOLO)):
  506. loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
  507. elif isinstance(self.bce, VarifocalLoss_YOLO):
  508. if fg_mask.sum():
  509. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  510. # 10.0x Faster than torch.one_hot
  511. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  512. dtype=torch.int64,
  513. device=target_labels.device) # (b, h*w, 80)
  514. cls_iou_targets.scatter_(2, target_labels.unsqueeze(-1), 1)
  515. cls_iou_targets = pos_ious * cls_iou_targets
  516. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  517. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  518. else:
  519. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  520. dtype=torch.int64,
  521. device=target_labels.device) # (b, h*w, 80)
  522. loss[2] = self.bce(pred_scores, cls_iou_targets.to(dtype)).sum() / max(fg_mask.sum(), 1) # BCE
  523. elif isinstance(self.bce, QualityfocalLoss_YOLO):
  524. if fg_mask.sum():
  525. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  526. # 10.0x Faster than torch.one_hot
  527. targets_onehot = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  528. dtype=torch.int64,
  529. device=target_labels.device) # (b, h*w, 80)
  530. targets_onehot.scatter_(2, target_labels.unsqueeze(-1), 1)
  531. cls_iou_targets = pos_ious * targets_onehot
  532. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  533. targets_onehot_pos = torch.where(fg_scores_mask > 0, targets_onehot, 0)
  534. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  535. else:
  536. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  537. dtype=torch.int64,
  538. device=target_labels.device) # (b, h*w, 80)
  539. targets_onehot_pos = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  540. dtype=torch.int64,
  541. device=target_labels.device) # (b, h*w, 80)
  542. loss[2] = self.bce(pred_scores, cls_iou_targets.to(dtype), targets_onehot_pos.to(torch.bool)).sum() / max(fg_mask.sum(), 1)
  543. if fg_mask.sum():
  544. # Bbox loss
  545. loss[0], loss[3] = self.bbox_loss(
  546. pred_distri,
  547. pred_bboxes,
  548. anchor_points,
  549. target_bboxes / stride_tensor,
  550. target_scores,
  551. target_scores_sum,
  552. fg_mask,
  553. ((imgsz[0] ** 2 + imgsz[1] ** 2) / torch.square(stride_tensor)).repeat(1, batch_size).transpose(1, 0)
  554. )
  555. # Masks loss
  556. masks = batch["masks"].to(self.device).float()
  557. if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
  558. masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
  559. loss[1] = self.calculate_segmentation_loss(
  560. fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
  561. )
  562. # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  563. else:
  564. loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
  565. if isinstance(self.bce, (EMASlideLoss, SlideLoss)):
  566. if fg_mask.sum():
  567. auto_iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True).mean()
  568. else:
  569. auto_iou = -1
  570. loss[2] = self.bce(pred_scores, target_scores.to(dtype), auto_iou).sum() / target_scores_sum # BCE
  571. loss[0] *= self.hyp.box # box gain
  572. loss[1] *= self.hyp.box # seg gain
  573. loss[2] *= self.hyp.cls # cls gain
  574. loss[3] *= self.hyp.dfl # dfl gain
  575. return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
  576. @staticmethod
  577. def single_mask_loss(
  578. gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
  579. ) -> torch.Tensor:
  580. """
  581. Compute the instance segmentation loss for a single image.
  582. Args:
  583. gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects.
  584. pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
  585. proto (torch.Tensor): Prototype masks of shape (32, H, W).
  586. xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
  587. area (torch.Tensor): Area of each ground truth bounding box of shape (n,).
  588. Returns:
  589. (torch.Tensor): The calculated mask loss for a single image.
  590. Notes:
  591. The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
  592. predicted masks from the prototype masks and predicted mask coefficients.
  593. """
  594. pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
  595. loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
  596. return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
  597. def calculate_segmentation_loss(
  598. self,
  599. fg_mask: torch.Tensor,
  600. masks: torch.Tensor,
  601. target_gt_idx: torch.Tensor,
  602. target_bboxes: torch.Tensor,
  603. batch_idx: torch.Tensor,
  604. proto: torch.Tensor,
  605. pred_masks: torch.Tensor,
  606. imgsz: torch.Tensor,
  607. overlap: bool,
  608. ) -> torch.Tensor:
  609. """
  610. Calculate the loss for instance segmentation.
  611. Args:
  612. fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
  613. masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
  614. target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
  615. target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
  616. batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
  617. proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
  618. pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
  619. imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
  620. overlap (bool): Whether the masks in `masks` tensor overlap.
  621. Returns:
  622. (torch.Tensor): The calculated loss for instance segmentation.
  623. Notes:
  624. The batch loss can be computed for improved speed at higher memory usage.
  625. For example, pred_mask can be computed as follows:
  626. pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
  627. """
  628. _, _, mask_h, mask_w = proto.shape
  629. loss = 0
  630. # Normalize to 0-1
  631. target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
  632. # Areas of target bboxes
  633. marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
  634. # Normalize to mask size
  635. mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
  636. for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
  637. fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
  638. if fg_mask_i.any():
  639. mask_idx = target_gt_idx_i[fg_mask_i]
  640. if overlap:
  641. gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
  642. gt_mask = gt_mask.float()
  643. else:
  644. gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
  645. loss += self.single_mask_loss(
  646. gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
  647. )
  648. # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
  649. else:
  650. loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
  651. return loss / fg_mask.sum()
  652. class v8PoseLoss(v8DetectionLoss):
  653. """Criterion class for computing training losses."""
  654. def __init__(self, model): # model must be de-paralleled
  655. """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance."""
  656. super().__init__(model)
  657. self.kpt_shape = model.model[-1].kpt_shape
  658. self.bce_pose = nn.BCEWithLogitsLoss()
  659. is_pose = self.kpt_shape == [17, 3]
  660. nkpt = self.kpt_shape[0] # number of keypoints
  661. sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
  662. self.keypoint_loss = KeypointLoss(sigmas=sigmas)
  663. def __call__(self, preds, batch):
  664. """Calculate the total loss and detach it."""
  665. loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
  666. feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
  667. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  668. (self.reg_max * 4, self.nc), 1
  669. )
  670. # B, grids, ..
  671. pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  672. pred_distri = pred_distri.permute(0, 2, 1).contiguous()
  673. pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
  674. dtype = pred_scores.dtype
  675. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  676. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  677. # Targets
  678. batch_size = pred_scores.shape[0]
  679. batch_idx = batch["batch_idx"].view(-1, 1)
  680. targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
  681. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  682. gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
  683. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  684. # Pboxes
  685. pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
  686. pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
  687. if isinstance(self.assigner, ATSSAssigner):
  688. anchors, _, n_anchors_list, _ = \
  689. generate_anchors(feats, self.fpn_strides, self.grid_cell_size, self.grid_cell_offset, device=feats[0].device)
  690. target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(anchors, n_anchors_list, gt_labels, gt_bboxes, mask_gt, pred_bboxes.detach() * stride_tensor)
  691. else:
  692. target_labels, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
  693. pred_scores.detach().sigmoid(),
  694. (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
  695. anchor_points * stride_tensor,
  696. gt_labels,
  697. gt_bboxes,
  698. mask_gt,
  699. )
  700. target_scores_sum = max(target_scores.sum(), 1)
  701. # Cls loss
  702. if isinstance(self.bce, (nn.BCEWithLogitsLoss, FocalLoss_YOLO)):
  703. loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
  704. elif isinstance(self.bce, VarifocalLoss_YOLO):
  705. if fg_mask.sum():
  706. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  707. # 10.0x Faster than torch.one_hot
  708. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  709. dtype=torch.int64,
  710. device=target_labels.device) # (b, h*w, 80)
  711. cls_iou_targets.scatter_(2, target_labels.unsqueeze(-1), 1)
  712. cls_iou_targets = pos_ious * cls_iou_targets
  713. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  714. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  715. else:
  716. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  717. dtype=torch.int64,
  718. device=target_labels.device) # (b, h*w, 80)
  719. loss[3] = self.bce(pred_scores, cls_iou_targets.to(dtype)).sum() / max(fg_mask.sum(), 1) # BCE
  720. elif isinstance(self.bce, QualityfocalLoss_YOLO):
  721. if fg_mask.sum():
  722. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  723. # 10.0x Faster than torch.one_hot
  724. targets_onehot = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  725. dtype=torch.int64,
  726. device=target_labels.device) # (b, h*w, 80)
  727. targets_onehot.scatter_(2, target_labels.unsqueeze(-1), 1)
  728. cls_iou_targets = pos_ious * targets_onehot
  729. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  730. targets_onehot_pos = torch.where(fg_scores_mask > 0, targets_onehot, 0)
  731. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  732. else:
  733. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  734. dtype=torch.int64,
  735. device=target_labels.device) # (b, h*w, 80)
  736. targets_onehot_pos = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  737. dtype=torch.int64,
  738. device=target_labels.device) # (b, h*w, 80)
  739. loss[3] = self.bce(pred_scores, cls_iou_targets.to(dtype), targets_onehot_pos.to(torch.bool)).sum() / max(fg_mask.sum(), 1)
  740. # Bbox loss
  741. if fg_mask.sum():
  742. target_bboxes /= stride_tensor
  743. loss[0], loss[4] = self.bbox_loss(
  744. pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask,
  745. ((imgsz[0] ** 2 + imgsz[1] ** 2) / torch.square(stride_tensor)).repeat(1, batch_size).transpose(1, 0)
  746. )
  747. keypoints = batch["keypoints"].to(self.device).float().clone()
  748. keypoints[..., 0] *= imgsz[1]
  749. keypoints[..., 1] *= imgsz[0]
  750. loss[1], loss[2] = self.calculate_keypoints_loss(
  751. fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
  752. )
  753. if isinstance(self.bce, (EMASlideLoss, SlideLoss)):
  754. if fg_mask.sum():
  755. auto_iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True).mean()
  756. else:
  757. auto_iou = -1
  758. loss[3] = self.bce(pred_scores, target_scores.to(dtype), auto_iou).sum() / target_scores_sum # BCE
  759. loss[0] *= self.hyp.box # box gain
  760. loss[1] *= self.hyp.pose # pose gain
  761. loss[2] *= self.hyp.kobj # kobj gain
  762. loss[3] *= self.hyp.cls # cls gain
  763. loss[4] *= self.hyp.dfl # dfl gain
  764. return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
  765. @staticmethod
  766. def kpts_decode(anchor_points, pred_kpts):
  767. """Decodes predicted keypoints to image coordinates."""
  768. y = pred_kpts.clone()
  769. y[..., :2] *= 2.0
  770. y[..., 0] += anchor_points[:, [0]] - 0.5
  771. y[..., 1] += anchor_points[:, [1]] - 0.5
  772. return y
  773. def calculate_keypoints_loss(
  774. self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
  775. ):
  776. """
  777. Calculate the keypoints loss for the model.
  778. This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
  779. based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
  780. a binary classification loss that classifies whether a keypoint is present or not.
  781. Args:
  782. masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
  783. target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
  784. keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
  785. batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
  786. stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
  787. target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
  788. pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
  789. Returns:
  790. (tuple): Returns a tuple containing:
  791. - kpts_loss (torch.Tensor): The keypoints loss.
  792. - kpts_obj_loss (torch.Tensor): The keypoints object loss.
  793. """
  794. batch_idx = batch_idx.flatten()
  795. batch_size = len(masks)
  796. # Find the maximum number of keypoints in a single image
  797. max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
  798. # Create a tensor to hold batched keypoints
  799. batched_keypoints = torch.zeros(
  800. (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
  801. )
  802. # TODO: any idea how to vectorize this?
  803. # Fill batched_keypoints with keypoints based on batch_idx
  804. for i in range(batch_size):
  805. keypoints_i = keypoints[batch_idx == i]
  806. batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
  807. # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
  808. target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
  809. # Use target_gt_idx_expanded to select keypoints from batched_keypoints
  810. selected_keypoints = batched_keypoints.gather(
  811. 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
  812. )
  813. # Divide coordinates by stride
  814. selected_keypoints /= stride_tensor.view(1, -1, 1, 1)
  815. kpts_loss = 0
  816. kpts_obj_loss = 0
  817. if masks.any():
  818. gt_kpt = selected_keypoints[masks]
  819. area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
  820. pred_kpt = pred_kpts[masks]
  821. kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
  822. kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
  823. if pred_kpt.shape[-1] == 3:
  824. kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
  825. return kpts_loss, kpts_obj_loss
  826. class v8ClassificationLoss:
  827. """Criterion class for computing training losses."""
  828. def __call__(self, preds, batch):
  829. """Compute the classification loss between predictions and true labels."""
  830. loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
  831. loss_items = loss.detach()
  832. return loss, loss_items
  833. class v8OBBLoss(v8DetectionLoss):
  834. def __init__(self, model):
  835. """
  836. Initializes v8OBBLoss with model, assigner, and rotated bbox loss.
  837. Note model must be de-paralleled.
  838. """
  839. super().__init__(model)
  840. self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
  841. self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
  842. def preprocess(self, targets, batch_size, scale_tensor):
  843. """Preprocesses the target counts and matches with the input batch size to output a tensor."""
  844. if targets.shape[0] == 0:
  845. out = torch.zeros(batch_size, 0, 6, device=self.device)
  846. else:
  847. i = targets[:, 0] # image index
  848. _, counts = i.unique(return_counts=True)
  849. counts = counts.to(dtype=torch.int32)
  850. out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
  851. for j in range(batch_size):
  852. matches = i == j
  853. n = matches.sum()
  854. if n:
  855. bboxes = targets[matches, 2:]
  856. bboxes[..., :4].mul_(scale_tensor)
  857. out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
  858. return out
  859. def __call__(self, preds, batch):
  860. """Calculate and return the loss for the YOLO model."""
  861. loss = torch.zeros(3, device=self.device) # box, cls, dfl
  862. feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
  863. batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width
  864. pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
  865. (self.reg_max * 4, self.nc), 1
  866. )
  867. # b, grids, ..
  868. pred_scores = pred_scores.permute(0, 2, 1).contiguous()
  869. pred_distri = pred_distri.permute(0, 2, 1).contiguous()
  870. pred_angle = pred_angle.permute(0, 2, 1).contiguous()
  871. dtype = pred_scores.dtype
  872. imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
  873. anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
  874. # targets
  875. try:
  876. batch_idx = batch["batch_idx"].view(-1, 1)
  877. targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
  878. rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
  879. targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
  880. targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
  881. gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
  882. mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
  883. except RuntimeError as e:
  884. raise TypeError(
  885. "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
  886. "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
  887. "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a "
  888. "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
  889. "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
  890. ) from e
  891. # Pboxes
  892. pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
  893. bboxes_for_assigner = pred_bboxes.clone().detach()
  894. # Only the first four elements need to be scaled
  895. bboxes_for_assigner[..., :4] *= stride_tensor
  896. target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
  897. pred_scores.detach().sigmoid(),
  898. bboxes_for_assigner.type(gt_bboxes.dtype),
  899. anchor_points * stride_tensor,
  900. gt_labels,
  901. gt_bboxes,
  902. mask_gt,
  903. )
  904. target_scores_sum = max(target_scores.sum(), 1)
  905. # Cls loss
  906. if isinstance(self.bce, (nn.BCEWithLogitsLoss, FocalLoss_YOLO)):
  907. loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
  908. elif isinstance(self.bce, VarifocalLoss_YOLO):
  909. if fg_mask.sum():
  910. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  911. # 10.0x Faster than torch.one_hot
  912. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  913. dtype=torch.int64,
  914. device=target_labels.device) # (b, h*w, 80)
  915. cls_iou_targets.scatter_(2, target_labels.unsqueeze(-1), 1)
  916. cls_iou_targets = pos_ious * cls_iou_targets
  917. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  918. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  919. else:
  920. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  921. dtype=torch.int64,
  922. device=target_labels.device) # (b, h*w, 80)
  923. loss[1] = self.bce(pred_scores, cls_iou_targets.to(dtype)).sum() / max(fg_mask.sum(), 1) # BCE
  924. elif isinstance(self.bce, QualityfocalLoss_YOLO):
  925. if fg_mask.sum():
  926. pos_ious = bbox_iou(pred_bboxes, target_bboxes / stride_tensor, xywh=False).clamp(min=1e-6).detach()
  927. # 10.0x Faster than torch.one_hot
  928. targets_onehot = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  929. dtype=torch.int64,
  930. device=target_labels.device) # (b, h*w, 80)
  931. targets_onehot.scatter_(2, target_labels.unsqueeze(-1), 1)
  932. cls_iou_targets = pos_ious * targets_onehot
  933. fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) # (b, h*w, 80)
  934. targets_onehot_pos = torch.where(fg_scores_mask > 0, targets_onehot, 0)
  935. cls_iou_targets = torch.where(fg_scores_mask > 0, cls_iou_targets, 0)
  936. else:
  937. cls_iou_targets = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  938. dtype=torch.int64,
  939. device=target_labels.device) # (b, h*w, 80)
  940. targets_onehot_pos = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.nc),
  941. dtype=torch.int64,
  942. device=target_labels.device) # (b, h*w, 80)
  943. loss[1] = self.bce(pred_scores, cls_iou_targets.to(dtype), targets_onehot_pos.to(torch.bool)).sum() / max(fg_mask.sum(), 1)
  944. # Bbox loss
  945. if fg_mask.sum():
  946. target_bboxes[..., :4] /= stride_tensor
  947. loss[0], loss[2] = self.bbox_loss(
  948. pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
  949. )
  950. else:
  951. loss[0] += (pred_angle * 0).sum()
  952. loss[0] *= self.hyp.box # box gain
  953. loss[1] *= self.hyp.cls # cls gain
  954. loss[2] *= self.hyp.dfl # dfl gain
  955. return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
  956. def bbox_decode(self, anchor_points, pred_dist, pred_angle):
  957. """
  958. Decode predicted object bounding box coordinates from anchor points and distribution.
  959. Args:
  960. anchor_points (torch.Tensor): Anchor points, (h*w, 2).
  961. pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
  962. pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
  963. Returns:
  964. (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
  965. """
  966. if self.use_dfl:
  967. b, a, c = pred_dist.shape # batch, anchors, channels
  968. pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
  969. return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
  970. class E2EDetectLoss:
  971. """Criterion class for computing training losses."""
  972. def __init__(self, model):
  973. """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
  974. self.one2many = v8DetectionLoss(model, tal_topk=10)
  975. self.one2one = v8DetectionLoss(model, tal_topk=1)
  976. def __call__(self, preds, batch):
  977. """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
  978. preds = preds[1] if isinstance(preds, tuple) else preds
  979. one2many = preds["one2many"]
  980. loss_one2many = self.one2many(one2many, batch)
  981. one2one = preds["one2one"]
  982. loss_one2one = self.one2one(one2one, batch)
  983. return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]