# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import math
import re
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou

class Profile(contextlib.ContextDecorator):
    """
    YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'.

    Example:
        ```python
        from ultralytics.utils.ops import Profile

        with Profile(device=device) as dt:
            pass  # slow operation here

        print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"
        ```
    """

    def __init__(self, t=0.0, device: torch.device = None):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial time. Defaults to 0.0.
            device (torch.device): Device used for model inference. Defaults to None (cpu).
        """
        self.t = t
        self.device = device
        self.cuda = bool(device and str(device).startswith("cuda"))

    def __enter__(self):
        """Start timing."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):  # noqa
        """Stop timing."""
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def __str__(self):
        """Returns a human-readable string representing the accumulated elapsed time in the profiler."""
        return f"Elapsed time is {self.t} s"

    def time(self):
        """Get current time."""
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.time()

def segment2box(segment, width=640, height=640):
    """
    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy).

    Args:
        segment (np.ndarray): the segment label, an (n, 2) array of xy points.
        width (int): the width of the image. Defaults to 640.
        height (int): the height of the image. Defaults to 640.

    Returns:
        (np.ndarray): the minimum and maximum x and y values of the segment.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x = x[inside]
    y = y[inside]
    return (
        np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
        if any(x)
        else np.zeros(4, dtype=segment.dtype)
    )  # xyxy

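# Illustrative sketch (not part of the upstream module): points outside the
# image are dropped before the min/max box is taken, so the 700-px point below
# is ignored.
# >>> seg = np.array([[10.0, 20.0], [50.0, 80.0], [700.0, 30.0]])
# >>> segment2box(seg, width=640, height=640)  # -> array([10., 20., 50., 80.])
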
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False):
    """
    Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally
    specified in (img1_shape) to the shape of a different image (img0_shape).

    Args:
        img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
        img0_shape (tuple): the shape of the target image, in the format of (height, width).
        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
            calculated based on the size difference between the two images.
        padding (bool): If True, assume the boxes are based on an image padded (letterboxed) YOLO-style. If False,
            do regular rescaling.
        xywh (bool): Whether the box format is xywh. Defaults to False.

    Returns:
        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (
            round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1),
            round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1),
        )  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        boxes[..., 0] -= pad[0]  # x padding
        boxes[..., 1] -= pad[1]  # y padding
        if not xywh:
            boxes[..., 2] -= pad[0]  # x padding
            boxes[..., 3] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    return clip_boxes(boxes, img0_shape)

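# Illustrative sketch: a 480x640 frame letterboxed to 640x640 gets 80 px of
# vertical padding, which is subtracted before the (here 1.0) gain division.
# Note that `boxes` is modified in place.
# >>> boxes = torch.tensor([[100.0, 120.0, 200.0, 220.0]])  # xyxy on the 640x640 input
# >>> scale_boxes((640, 640), boxes, (480, 640))  # -> tensor([[100., 40., 200., 140.]])
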
def make_divisible(x, divisor):
    """
    Returns the nearest number greater than or equal to x that is divisible by the given divisor.

    Args:
        x (int): The number to make divisible.
        divisor (int | torch.Tensor): The divisor.

    Returns:
        (int): The nearest number greater than or equal to x that is divisible by the divisor.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor

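# Illustrative sketch: values are rounded *up* to the next multiple of the divisor.
# >>> make_divisible(97, 32)   # -> 128
# >>> make_divisible(128, 32)  # -> 128
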
def nms_rotated(boxes, scores, threshold=0.45):
    """
    NMS for oriented bounding boxes (OBB), powered by probiou and fast-nms.

    Args:
        boxes (torch.Tensor): (N, 5), xywhr.
        scores (torch.Tensor): (N, ).
        threshold (float): IoU threshold.

    Returns:
        (torch.Tensor): Indices of the boxes to keep.
    """
    if len(boxes) == 0:
        return np.empty((0,), dtype=np.int8)
    sorted_idx = torch.argsort(scores, descending=True)
    boxes = boxes[sorted_idx]
    ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
    pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
    return sorted_idx[pick]

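# Illustrative sketch: two nearly coincident rotated boxes (xywhr, radians);
# their probabilistic IoU should exceed the threshold, so the lower-scoring
# box is suppressed and only index 0 is expected back.
# >>> boxes = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0], [51.0, 50.0, 20.0, 10.0, 0.0]])
# >>> scores = torch.tensor([0.9, 0.8])
# >>> nms_rotated(boxes, scores, threshold=0.45)  # expected: tensor([0])
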
def bbox_iou_for_nms(
    box1, box2, xywh=False, GIoU=False, DIoU=False, CIoU=False, EIoU=False, SIoU=False, ShapeIoU=False, eps=1e-7, scale=0.0
):
    """
    Calculate Intersection over Union (IoU) of box1(1, 4) to box2(n, 4).

    Args:
        box1 (torch.Tensor): A tensor representing a single bounding box with shape (1, 4).
        box2 (torch.Tensor): A tensor representing n bounding boxes with shape (n, 4).
        xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
            (x1, y1, x2, y2) format. Defaults to False.
        GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False.
        DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False.
        CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False.
        EIoU (bool, optional): If True, calculate Efficient IoU. Defaults to False.
        SIoU (bool, optional): If True, calculate Scylla IoU. Defaults to False.
        ShapeIoU (bool, optional): If True, calculate Shape IoU. Defaults to False.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.
        scale (float, optional): Shape-scale exponent used by ShapeIoU. Defaults to 0.0.

    Returns:
        (torch.Tensor): IoU, GIoU, DIoU, CIoU, EIoU, SIoU, or ShapeIoU values depending on the specified flags.
    """
    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp_(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU or EIoU or SIoU or ShapeIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU or ShapeIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw**2 + ch**2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist**2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = cw**2 + eps
                ch2 = ch**2 + eps
                return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)  # EIoU
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw**2 + s_ch**2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                return iou - 0.5 * (distance_cost + shape_cost) + eps  # SIoU
            elif ShapeIoU:
                # Shape-Distance term
                ww = 2 * torch.pow(w2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
                hh = 2 * torch.pow(h2, scale) / (torch.pow(w2, scale) + torch.pow(h2, scale))
                cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex width
                ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
                c2 = cw**2 + ch**2 + eps  # convex diagonal squared
                center_distance_x = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2) / 4
                center_distance_y = ((b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
                center_distance = hh * center_distance_x + ww * center_distance_y
                distance = center_distance / c2
                # Shape-Shape term
                omiga_w = hh * torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = ww * torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                return iou - distance - 0.5 * shape_cost  # ShapeIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU

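# Illustrative sketch: identical boxes give IoU 1; the half-offset box below
# overlaps 25 of 175 union pixels. CIoU additionally penalizes center distance
# and aspect-ratio mismatch, so it is never larger than plain IoU.
# >>> b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
# >>> b2 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
# >>> bbox_iou_for_nms(b1, b2)             # -> approx tensor([[1.0000], [0.1429]])
# >>> bbox_iou_for_nms(b1, b2, CIoU=True)  # CIoU <= IoU for the offset box
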
def soft_nms(bboxes, scores, iou_thresh=0.5, sigma=0.5, score_threshold=0.25):
    """
    Gaussian Soft-NMS (https://arxiv.org/abs/1704.04503): instead of removing neighbors that overlap the current
    top-scoring box, decay their scores and discard them only once they fall below score_threshold.

    Args:
        bboxes (torch.Tensor): (N, 4) boxes in (x1, y1, x2, y2) format.
        scores (torch.Tensor): (N, ) confidence scores, decayed in place.
        iou_thresh (float): IoU above which a neighbor's score is decayed.
        sigma (float): Gaussian decay parameter.
        score_threshold (float): boxes whose decayed score falls below this value are discarded.

    Returns:
        (torch.Tensor): Indices of the kept boxes, in selection order.
    """
    order = torch.arange(0, scores.size(0)).to(bboxes.device)
    keep = []
    while order.numel() > 0:
        if order.numel() == 1:
            keep.append(int(order[0]))
            break
        i = int(order[0])
        keep.append(i)
        iou = bbox_iou_for_nms(bboxes[i : i + 1], bboxes[order[1:]]).view(-1)
        idx = (iou > iou_thresh).nonzero().view(-1)
        if idx.numel() > 0:
            iou = iou[idx]
            newScores = torch.exp(-torch.pow(iou, 2) / sigma)
            scores[order[idx + 1]] *= newScores
        newOrder = (scores[order[1:]] > score_threshold).nonzero().view(-1)
        if newOrder.numel() == 0:
            break
        maxScoreIndex = int(torch.argmax(scores[order[newOrder + 1]]))
        if maxScoreIndex != 0:
            newOrder[[0, maxScoreIndex]] = newOrder[[maxScoreIndex, 0]]
        order = order[newOrder + 1]
    return torch.tensor(keep, dtype=torch.long, device=bboxes.device)

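# Illustrative sketch: the second box overlaps the first (IoU ≈ 0.68), so its
# score is decayed rather than removed; the distant third box is untouched.
# Note that `scores` is modified in place.
# >>> boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [50.0, 50.0, 60.0, 60.0]])
# >>> scores = torch.tensor([0.9, 0.8, 0.7])
# >>> soft_nms(boxes, scores, iou_thresh=0.5)  # keeps all three here: tensor([0, 2, 1])
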
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
    in_place=True,
    rotated=False,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.
        in_place (bool): If True, the input prediction tensor will be modified in place.
        rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """
    import torchvision  # scope for faster 'import ultralytics'

    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        classes = torch.tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6:  # end-to-end model (BNC, i.e. 1,300,6)
        output = [pred[pred[:, 4] > conf_thres] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        return output

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of masks
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        if in_place:
            prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy
        else:
            prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1)  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == classes).any(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = nms_rotated(boxes, scores, iou_thres)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
            # i = soft_nms(boxes, scores, iou_thres)  # Soft-NMS alternative
        i = i[:max_det]  # limit detections

        # # Experimental
        # merge = False  # use merge-NMS
        # if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
        #     # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
        #     from .metrics import box_iou
        #     iou = box_iou(boxes[i], boxes) > iou_thres  # IoU matrix
        #     weights = iou * scores[None]  # box weights
        #     x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
        #     redundant = True  # require redundant detections
        #     if redundant:
        #         i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return output

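# Illustrative shape sketch: a raw (batch, 4 + nc, num_boxes) prediction with
# nc=2 classes and no masks yields (n, 6) detections per image. Random values
# stand in for real model output here.
# >>> pred = torch.rand(1, 6, 100)
# >>> out = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# >>> len(out), out[0].shape[-1]  # -> (1, 6): x1, y1, x2, y2, conf, cls
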
def clip_boxes(boxes, shape):
    """
    Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.

    Args:
        boxes (torch.Tensor): the bounding boxes to clip.
        shape (tuple): the shape of the image.

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped boxes
    """
    if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
        boxes[..., 1] = boxes[..., 1].clamp(0, shape[0])  # y1
        boxes[..., 2] = boxes[..., 2].clamp(0, shape[1])  # x2
        boxes[..., 3] = boxes[..., 3].clamp(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2
    return boxes

def clip_coords(coords, shape):
    """
    Clip line coordinates to the image boundaries.

    Args:
        coords (torch.Tensor | numpy.ndarray): A list of line coordinates.
        shape (tuple): A tuple of integers representing the size of the image in the format (height, width).

    Returns:
        (torch.Tensor | numpy.ndarray): Clipped coordinates
    """
    if isinstance(coords, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
        coords[..., 0] = coords[..., 0].clamp(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clamp(0, shape[0])  # y
    else:  # np.array (faster grouped)
        coords[..., 0] = coords[..., 0].clip(0, shape[1])  # x
        coords[..., 1] = coords[..., 1].clip(0, shape[0])  # y
    return coords

def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Takes a mask, and resizes it to the original image size.

    Args:
        masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
        im0_shape (tuple): the original image shape.
        ratio_pad (tuple): the ratio of the padding to the original image.

    Returns:
        masks (np.ndarray): The masks rescaled to the original image shape.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    im1_shape = masks.shape
    if im1_shape[:2] == im0_shape[:2]:
        return masks
    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1])  # gain = old / new
        pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2  # wh padding
    else:
        # gain = ratio_pad[0][0]
        pad = ratio_pad[1]
    top, left = int(pad[1]), int(pad[0])  # y, x
    bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0])

    if len(masks.shape) < 2:
        raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]
    masks = cv2.resize(masks, (im0_shape[1], im0_shape[0]))
    if len(masks.shape) == 2:
        masks = masks[:, :, None]
    return masks

def xyxy2xywh(x):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y

def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    xy = x[..., :2]  # centers
    wh = x[..., 2:] / 2  # half width-height
    y[..., :2] = xy - wh  # top left xy
    y[..., 2:] = xy + wh  # bottom right xy
    return y

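# Illustrative sketch: xywh2xyxy and xyxy2xywh are inverses of each other.
# >>> b = torch.tensor([[50.0, 50.0, 20.0, 10.0]])  # cx, cy, w, h
# >>> xywh2xyxy(b)             # -> tensor([[40., 45., 60., 55.]])
# >>> xyxy2xywh(xywh2xyxy(b))  # -> tensor([[50., 50., 20., 10.]])
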
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """
    Convert normalized bounding box coordinates to pixel coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The bounding box coordinates.
        w (int): Width of the image. Defaults to 640.
        h (int): Height of the image. Defaults to 640.
        padw (int): Padding width. Defaults to 0.
        padh (int): Padding height. Defaults to 0.

    Returns:
        y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
    y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top left y
    y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom right x
    y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom right y
    return y

def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
    width and height are normalized to image dimensions.

    Args:
        x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
        w (int): The width of the image. Defaults to 640.
        h (int): The height of the image. Defaults to 640.
        clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False.
        eps (float): The minimum value of the box's width and height. Defaults to 0.0.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
    """
    if clip:
        x = clip_boxes(x, (h - eps, w - eps))
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = torch.empty_like(x) if isinstance(x, torch.Tensor) else np.empty_like(x)  # faster than clone/copy
    y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # x center
    y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # y center
    y[..., 2] = (x[..., 2] - x[..., 0]) / w  # width
    y[..., 3] = (x[..., 3] - x[..., 1]) / h  # height
    return y

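# Illustrative sketch: normalization divides by image width/height, so
# xywhn2xyxy(xyxy2xywhn(b)) round-trips for in-bounds boxes.
# >>> b = torch.tensor([[64.0, 64.0, 128.0, 128.0]])
# >>> xyxy2xywhn(b, w=640, h=640)  # -> tensor([[0.1500, 0.1500, 0.1000, 0.1000]])
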
def xywh2ltwh(x):
    """
    Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y

def xyxy2ltwh(x):
    """
    Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y

def ltwh2xywh(x):
    """
    Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    Args:
        x (torch.Tensor): the input tensor.

    Returns:
        y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y

def xyxyxyxy2xywhr(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are
    returned in radians from 0 to pi/2.

    Args:
        x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5).
    """
    is_torch = isinstance(x, torch.Tensor)
    points = x.cpu().numpy() if is_torch else x
    points = points.reshape(len(x), -1, 2)
    rboxes = []
    for pts in points:
        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
        # especially when some objects are cut off by augmentations in the dataloader.
        (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
        rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
    return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)

def xywhr2xyxyxyxy(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values
    should be in radians from 0 to pi/2.

    Args:
        x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5).

    Returns:
        (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
    """
    cos, sin, cat, stack = (
        (torch.cos, torch.sin, torch.cat, torch.stack)
        if isinstance(x, torch.Tensor)
        else (np.cos, np.sin, np.concatenate, np.stack)
    )

    ctr = x[..., :2]
    w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
    cos_value, sin_value = cos(angle), sin(angle)
    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
    vec1 = cat(vec1, -1)
    vec2 = cat(vec2, -1)
    pt1 = ctr + vec1 + vec2
    pt2 = ctr + vec1 - vec2
    pt3 = ctr - vec1 - vec2
    pt4 = ctr - vec1 + vec2
    return stack([pt1, pt2, pt3, pt4], -2)

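# Illustrative round trip (angles in radians): corners from xywhr2xyxyxyxy fed
# back through xyxyxyxy2xywhr should recover approximately the original box,
# up to cv2.minAreaRect's angle convention.
# >>> rbox = torch.tensor([[100.0, 100.0, 40.0, 20.0, 0.3]])
# >>> corners = xywhr2xyxyxyxy(rbox)         # shape (1, 4, 2)
# >>> xyxyxyxy2xywhr(corners.reshape(1, 8))  # approx [[100., 100., 40., 20., 0.3]]
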
def ltwh2xyxy(x):
    """
    It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): the input boxes.

    Returns:
        y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # bottom-right x
    y[..., 3] = x[..., 3] + x[..., 1]  # bottom-right y
    return y

def segments2boxes(segments):
    """
    It converts segment labels to box labels, i.e. (xy1, xy2, ...) to (xywh).

    Args:
        segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates.

    Returns:
        (np.ndarray): the xywh coordinates of the bounding boxes.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh

def resample_segments(segments, n=1000):
    """
    Inputs a list of (m, 2) segment arrays and returns them resampled to n points each.

    Args:
        segments (list): a list of (m, 2) arrays, where m is the number of points in the segment.
        n (int): number of points to resample each segment to. Defaults to 1000.

    Returns:
        segments (list): the resampled segments.
    """
    for i, s in enumerate(segments):
        s = np.concatenate((s, s[0:1, :]), axis=0)  # close the segment
        x = np.linspace(0, len(s) - 1, n)
        xp = np.arange(len(s))
        segments[i] = (
            np.concatenate([np.interp(x, xp, s[:, j]) for j in range(2)], dtype=np.float32).reshape(2, -1).T
        )  # segment xy
    return segments

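# Illustrative sketch: a 3-point triangle is closed and linearly resampled to a
# fixed number of points, keeping the (n, 2) xy layout.
# >>> seg = [np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]], dtype=np.float32)]
# >>> resample_segments(seg, n=100)[0].shape  # -> (100, 2)
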
def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.

    Args:
        masks (torch.Tensor): [n, h, w] tensor of masks.
        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form.

    Returns:
        (torch.Tensor): The masks are being cropped to the bounding box.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x coordinates, shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y coordinates, shape(1,h,1)
    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

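# Illustrative sketch: everything outside the xyxy box is zeroed, so a full
# 8x8 mask cropped to a 4x4 box keeps 16 positive pixels.
# >>> masks = torch.ones(1, 8, 8)
# >>> boxes = torch.tensor([[2.0, 2.0, 6.0, 6.0]])
# >>> crop_mask(masks, boxes).sum()  # -> tensor(16.)
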
def process_mask_upsample(protos, masks_in, bboxes, shape):
    """
    Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher
    quality but is slower.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
        shape (tuple): the size of the input image (h,w).

    Returns:
        (torch.Tensor): The upsampled masks.
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
    masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.0)

def process_mask(protos, masks_in, bboxes, shape, upsample=False):
    """
    Apply masks to bounding boxes using the output of the mask head.

    Args:
        protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w].
        masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS.
        bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS.
        shape (tuple): A tuple of integers representing the size of the input image in the format (h, w).
        upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False.

    Returns:
        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
            are the height and width of the input image. The mask is applied to the bounding boxes.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
    width_ratio = mw / iw
    height_ratio = mh / ih

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= width_ratio
    downsampled_bboxes[:, 2] *= width_ratio
    downsampled_bboxes[:, 3] *= height_ratio
    downsampled_bboxes[:, 1] *= height_ratio

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    return masks.gt_(0.0)

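# Illustrative shape sketch with random tensors: 32 prototypes on a 160x160
# grid and two detections on a 640x640 input produce two full-size masks.
# >>> protos = torch.rand(32, 160, 160)
# >>> masks_in = torch.rand(2, 32)
# >>> bboxes = torch.tensor([[0.0, 0.0, 320.0, 320.0], [320.0, 320.0, 640.0, 640.0]])
# >>> process_mask(protos, masks_in, bboxes, (640, 640), upsample=True).shape
# torch.Size([2, 640, 640])
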
def process_mask_native(protos, masks_in, bboxes, shape):
    """
    It takes the output of the mask head, and crops it after upsampling to the bounding boxes.

    Args:
        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
        bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
        shape (tuple): the size of the input image (h,w).

    Returns:
        masks (torch.Tensor): The returned masks with dimensions [n, h, w].
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.0)

def scale_masks(masks, shape, padding=True):
    """
    Rescale segment masks to shape.

    Args:
        masks (torch.Tensor): (N, C, H, W).
        shape (tuple): Height and width.
        padding (bool): If True, assume the masks are based on an image padded (letterboxed) YOLO-style. If False,
            do regular rescaling.

    Returns:
        (torch.Tensor): The rescaled masks of shape (N, C, shape[0], shape[1]).
    """
    mh, mw = masks.shape[2:]
    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
    pad = [mw - shape[1] * gain, mh - shape[0] * gain]  # wh padding
    if padding:
        pad[0] /= 2
        pad[1] /= 2
    top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0)  # y, x
    bottom, right = (int(mh - pad[1]), int(mw - pad[0]))
    masks = masks[..., top:bottom, left:right]

    masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False)  # NCHW
    return masks

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True):
    """
    Rescale segment coordinates (xy) from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): The shape of the image that the coords are from.
        coords (torch.Tensor): the coords to be scaled of shape n,2.
        img0_shape (tuple): the shape of the image that the segmentation is being applied to.
        ratio_pad (tuple): the ratio of the image size to the padded image size.
        normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False.
        padding (bool): If True, assume the coords are based on an image padded (letterboxed) YOLO-style. If False,
            do regular rescaling.

    Returns:
        coords (torch.Tensor): The scaled coordinates.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        coords[..., 0] -= pad[0]  # x padding
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_shape[1]  # width
        coords[..., 1] /= img0_shape[0]  # height
    return coords

def regularize_rboxes(rboxes):
    """
    Regularize rotated boxes in range [0, pi/2].

    Args:
        rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format.

    Returns:
        (torch.Tensor): The regularized boxes.
    """
    x, y, w, h, t = rboxes.unbind(dim=-1)
    # Swap edge and angle if h >= w
    w_ = torch.where(w > h, w, h)
    h_ = torch.where(w > h, h, w)
    t = torch.where(w > h, t, t + math.pi / 2) % math.pi
    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes

def masks2segments(masks, strategy="largest"):
    """
    It takes a list of masks(n,h,w) and returns a list of segments(n,xy).

    Args:
        masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160).
        strategy (str): 'concat' or 'largest'. Defaults to largest.

    Returns:
        segments (List): list of segment masks.
    """
    segments = []
    for x in masks.int().cpu().numpy().astype("uint8"):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == "concat":  # concatenate all segments
                c = np.concatenate([x.reshape(-1, 2) for x in c])
            elif strategy == "largest":  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype("float32"))
    return segments

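# Illustrative sketch: a filled square mask yields one contour of (n, 2) xy
# points; with strategy="concat" all contours would be concatenated instead.
# >>> masks = torch.zeros(1, 160, 160)
# >>> masks[0, 40:80, 40:80] = 1.0
# >>> masks2segments(masks)[0].shape  # -> (n_points, 2)
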
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
    """
    Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout.

    Args:
        batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32.

    Returns:
        (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8.
    """
    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()

def clean_str(s):
    """
    Cleans a string by replacing special characters with an underscore _.

    Args:
        s (str): a string needing special characters replaced.

    Returns:
        (str): a string with special characters replaced by an underscore _.
    """
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)

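# Illustrative sketch: each listed special character becomes an underscore.
# >>> clean_str("exp#1 (test)?")  # -> 'exp_1 _test__'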