# Ultralytics YOLO 🚀, AGPL-3.0 license

import torch

from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
  7. class FastSAMPredictor(DetectionPredictor):
  8. """
  9. FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
  10. YOLO framework.
  11. This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
  12. It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
  13. for single-class segmentation.
  14. Attributes:
  15. cfg (dict): Configuration parameters for prediction.
  16. overrides (dict, optional): Optional parameter overrides for custom behavior.
  17. _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
  18. """
  19. def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
  20. """
  21. Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
  22. Args:
  23. cfg (dict): Configuration parameters for prediction.
  24. overrides (dict, optional): Optional parameter overrides for custom behavior.
  25. _callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
  26. """
  27. super().__init__(cfg, overrides, _callbacks)
  28. self.args.task = "segment"
  29. def postprocess(self, preds, img, orig_imgs):
  30. """
  31. Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
  32. size, and returns the final results.
  33. Args:
  34. preds (list): The raw output predictions from the model.
  35. img (torch.Tensor): The processed image tensor.
  36. orig_imgs (list | torch.Tensor): The original image or list of images.
  37. Returns:
  38. (list): A list of Results objects, each containing processed boxes, masks, and other metadata.
  39. """
  40. p = ops.non_max_suppression(
  41. preds[0],
  42. self.args.conf,
  43. self.args.iou,
  44. agnostic=self.args.agnostic_nms,
  45. max_det=self.args.max_det,
  46. nc=1, # set to 1 class since SAM has no class predictions
  47. classes=self.args.classes,
  48. )
  49. full_box = torch.zeros(p[0].shape[1], device=p[0].device)
  50. full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
  51. full_box = full_box.view(1, -1)
  52. critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
  53. if critical_iou_index.numel() != 0:
  54. full_box[0][4] = p[0][critical_iou_index][:, 4]
  55. full_box[0][6:] = p[0][critical_iou_index][:, 6:]
  56. p[0][critical_iou_index] = full_box
  57. if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
  58. orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
  59. results = []
  60. proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
  61. for i, pred in enumerate(p):
  62. orig_img = orig_imgs[i]
  63. img_path = self.batch[0][i]
  64. if not len(pred): # save empty boxes
  65. masks = None
  66. elif self.args.retina_masks:
  67. pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
  68. masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
  69. else:
  70. masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
  71. pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
  72. results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
  73. return results