utils.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
      """
      Snap box edges lying within ``threshold`` pixels of the image border onto that border.

      The update happens in place: ``boxes`` itself is modified, and the same tensor is returned.

      Args:
          boxes (torch.Tensor): (n, 4) boxes whose columns are read as x1, y1, x2, y2.
          image_shape (tuple): (height, width) of the image.
          threshold (int): pixel distance; an edge strictly closer than this to a border is moved onto it.

      Returns:
          adjusted_boxes (torch.Tensor): the ``boxes`` tensor after adjustment.
      """
      img_h, img_w = image_shape

      # Each mask reads a single column, and each write below touches only that same column,
      # so building every mask before any write gives the same result as interleaving them.
      near_left = boxes[:, 0] < threshold
      near_top = boxes[:, 1] < threshold
      near_right = boxes[:, 2] > img_w - threshold
      near_bottom = boxes[:, 3] > img_h - threshold

      boxes[near_left, 0] = 0
      boxes[near_top, 1] = 0
      boxes[near_right, 2] = img_w
      boxes[near_bottom, 3] = img_h
      return boxes
  def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
      """
      Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.

      ``boxes`` is first passed through ``adjust_bboxes_to_image_border``, which writes into the
      tensor it is given, so the caller's ``boxes`` tensor is snapped to the image border as a
      side effect. Pairs whose union area is zero (both boxes degenerate) get an IoU of 0 rather
      than NaN.

      Args:
          box1 (torch.Tensor): (4, ) reference box as (x1, y1, x2, y2)
          boxes (torch.Tensor): (n, 4) boxes as (x1, y1, x2, y2)
          iou_thres (float): IoU threshold; only boxes with IoU strictly greater are selected
          image_shape (tuple): (height, width), used for the border adjustment
          raw_output (bool): If True, return the raw IoU values instead of the indices

      Returns:
          high_iou_indices (torch.Tensor): 1-D indices of ``boxes`` with IoU > iou_thres.
              When ``raw_output`` is True, the (n, ) IoU tensor is returned instead, or the
              int ``0`` if ``boxes`` is empty.
      """
      boxes = adjust_bboxes_to_image_border(boxes, image_shape)

      # Corners of the intersection rectangle between box1 and every box (box1 broadcasts)
      x1 = torch.max(box1[0], boxes[:, 0])
      y1 = torch.max(box1[1], boxes[:, 1])
      x2 = torch.min(box1[2], boxes[:, 2])
      y2 = torch.min(box1[3], boxes[:, 3])

      # Non-overlapping pairs give a negative width/height; clamp so their intersection is 0
      intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

      box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
      box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
      union = box1_area + box2_area - intersection

      iou = intersection / union  # shape (n, )
      # Two zero-area boxes give 0 / 0 = NaN; report such pairs as non-overlapping instead.
      iou = torch.where(union == 0, torch.zeros_like(iou), iou)

      if raw_output:
          # Kept for backward compatibility: callers receive the int 0 for empty input.
          return 0 if iou.numel() == 0 else iou
      # Indices of boxes whose IoU exceeds the threshold
      return torch.nonzero(iou > iou_thres).flatten()