speed_estimation.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. from collections import defaultdict
  3. from time import time
  4. import cv2
  5. import numpy as np
  6. from ultralytics.utils.checks import check_imshow
  7. from ultralytics.utils.plotting import Annotator, colors
  8. class SpeedEstimator:
  9. """A class to estimate the speed of objects in a real-time video stream based on their tracks."""
  10. def __init__(self, names, reg_pts=None, view_img=False, line_thickness=2, region_thickness=5, spdl_dist_thresh=10):
  11. """
  12. Initializes the SpeedEstimator with the given parameters.
  13. Args:
  14. names (dict): Dictionary of class names.
  15. reg_pts (list, optional): List of region points for speed estimation. Defaults to [(20, 400), (1260, 400)].
  16. view_img (bool, optional): Whether to display the image with annotations. Defaults to False.
  17. line_thickness (int, optional): Thickness of the lines for drawing boxes and tracks. Defaults to 2.
  18. region_thickness (int, optional): Thickness of the region lines. Defaults to 5.
  19. spdl_dist_thresh (int, optional): Distance threshold for speed calculation. Defaults to 10.
  20. """
  21. # Visual & image information
  22. self.im0 = None
  23. self.annotator = None
  24. self.view_img = view_img
  25. # Region information
  26. self.reg_pts = reg_pts if reg_pts is not None else [(20, 400), (1260, 400)]
  27. self.region_thickness = region_thickness
  28. # Tracking information
  29. self.clss = None
  30. self.names = names
  31. self.boxes = None
  32. self.trk_ids = None
  33. self.trk_pts = None
  34. self.line_thickness = line_thickness
  35. self.trk_history = defaultdict(list)
  36. # Speed estimation information
  37. self.current_time = 0
  38. self.dist_data = {}
  39. self.trk_idslist = []
  40. self.spdl_dist_thresh = spdl_dist_thresh
  41. self.trk_previous_times = {}
  42. self.trk_previous_points = {}
  43. # Check if the environment supports imshow
  44. self.env_check = check_imshow(warn=True)
  45. def extract_tracks(self, tracks):
  46. """
  47. Extracts results from the provided tracking data.
  48. Args:
  49. tracks (list): List of tracks obtained from the object tracking process.
  50. """
  51. self.boxes = tracks[0].boxes.xyxy.cpu()
  52. self.clss = tracks[0].boxes.cls.cpu().tolist()
  53. self.trk_ids = tracks[0].boxes.id.int().cpu().tolist()
  54. def store_track_info(self, track_id, box):
  55. """
  56. Stores track data.
  57. Args:
  58. track_id (int): Object track id.
  59. box (list): Object bounding box data.
  60. Returns:
  61. (list): Updated tracking history for the given track_id.
  62. """
  63. track = self.trk_history[track_id]
  64. bbox_center = (float((box[0] + box[2]) / 2), float((box[1] + box[3]) / 2))
  65. track.append(bbox_center)
  66. if len(track) > 30:
  67. track.pop(0)
  68. self.trk_pts = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
  69. return track
  70. def plot_box_and_track(self, track_id, box, cls, track):
  71. """
  72. Plots track and bounding box.
  73. Args:
  74. track_id (int): Object track id.
  75. box (list): Object bounding box data.
  76. cls (str): Object class name.
  77. track (list): Tracking history for drawing tracks path.
  78. """
  79. speed_label = f"{int(self.dist_data[track_id])} km/h" if track_id in self.dist_data else self.names[int(cls)]
  80. bbox_color = colors(int(track_id)) if track_id in self.dist_data else (255, 0, 255)
  81. self.annotator.box_label(box, speed_label, bbox_color)
  82. cv2.polylines(self.im0, [self.trk_pts], isClosed=False, color=(0, 255, 0), thickness=1)
  83. cv2.circle(self.im0, (int(track[-1][0]), int(track[-1][1])), 5, bbox_color, -1)
  84. def calculate_speed(self, trk_id, track):
  85. """
  86. Calculates the speed of an object.
  87. Args:
  88. trk_id (int): Object track id.
  89. track (list): Tracking history for drawing tracks path.
  90. """
  91. if not self.reg_pts[0][0] < track[-1][0] < self.reg_pts[1][0]:
  92. return
  93. if self.reg_pts[1][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[1][1] + self.spdl_dist_thresh:
  94. direction = "known"
  95. elif self.reg_pts[0][1] - self.spdl_dist_thresh < track[-1][1] < self.reg_pts[0][1] + self.spdl_dist_thresh:
  96. direction = "known"
  97. else:
  98. direction = "unknown"
  99. if self.trk_previous_times.get(trk_id) != 0 and direction != "unknown" and trk_id not in self.trk_idslist:
  100. self.trk_idslist.append(trk_id)
  101. time_difference = time() - self.trk_previous_times[trk_id]
  102. if time_difference > 0:
  103. dist_difference = np.abs(track[-1][1] - self.trk_previous_points[trk_id][1])
  104. speed = dist_difference / time_difference
  105. self.dist_data[trk_id] = speed
  106. self.trk_previous_times[trk_id] = time()
  107. self.trk_previous_points[trk_id] = track[-1]
  108. def estimate_speed(self, im0, tracks, region_color=(255, 0, 0)):
  109. """
  110. Estimates the speed of objects based on tracking data.
  111. Args:
  112. im0 (ndarray): Image.
  113. tracks (list): List of tracks obtained from the object tracking process.
  114. region_color (tuple, optional): Color to use when drawing regions. Defaults to (255, 0, 0).
  115. Returns:
  116. (ndarray): The image with annotated boxes and tracks.
  117. """
  118. self.im0 = im0
  119. if tracks[0].boxes.id is None:
  120. if self.view_img and self.env_check:
  121. self.display_frames()
  122. return im0
  123. self.extract_tracks(tracks)
  124. self.annotator = Annotator(self.im0, line_width=self.line_thickness)
  125. self.annotator.draw_region(reg_pts=self.reg_pts, color=region_color, thickness=self.region_thickness)
  126. for box, trk_id, cls in zip(self.boxes, self.trk_ids, self.clss):
  127. track = self.store_track_info(trk_id, box)
  128. if trk_id not in self.trk_previous_times:
  129. self.trk_previous_times[trk_id] = 0
  130. self.plot_box_and_track(trk_id, box, cls, track)
  131. self.calculate_speed(trk_id, track)
  132. if self.view_img and self.env_check:
  133. self.display_frames()
  134. return im0
  135. def display_frames(self):
  136. """Displays the current frame."""
  137. cv2.imshow("Ultralytics Speed Estimation", self.im0)
  138. if cv2.waitKey(1) & 0xFF == ord("q"):
  139. return
  140. if __name__ == "__main__":
  141. names = {0: "person", 1: "car"} # example class names
  142. speed_estimator = SpeedEstimator(names)