val.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import torch
  3. from ultralytics.models.yolo.detect import DetectionValidator
  4. from ultralytics.utils import ops
  5. __all__ = ["NASValidator"]
  6. class NASValidator(DetectionValidator):
  7. """
  8. Ultralytics YOLO NAS Validator for object detection.
  9. Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
  10. generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
  11. ultimately producing the final detections.
  12. Attributes:
  13. args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
  14. lb (torch.Tensor): Optional tensor for multilabel NMS.
  15. Example:
  16. ```python
  17. from ultralytics import NAS
  18. model = NAS('yolo_nas_s')
  19. validator = model.validator
  20. # Assumes that raw_preds are available
  21. final_preds = validator.postprocess(raw_preds)
  22. ```
  23. Note:
  24. This class is generally not instantiated directly but is used internally within the `NAS` class.
  25. """
  26. def postprocess(self, preds_in):
  27. """Apply Non-maximum suppression to prediction outputs."""
  28. boxes = ops.xyxy2xywh(preds_in[0][0])
  29. preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
  30. return ops.non_max_suppression(
  31. preds,
  32. self.args.conf,
  33. self.args.iou,
  34. labels=self.lb,
  35. multi_label=False,
  36. agnostic=self.args.single_cls,
  37. max_det=self.args.max_det,
  38. max_time_img=0.5,
  39. )