# ai_gym.py (4.6 KB)
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import cv2
  3. from ultralytics.utils.checks import check_imshow
  4. from ultralytics.utils.plotting import Annotator
  5. class AIGym:
  6. """A class to manage the gym steps of people in a real-time video stream based on their poses."""
  7. def __init__(
  8. self,
  9. kpts_to_check,
  10. line_thickness=2,
  11. view_img=False,
  12. pose_up_angle=145.0,
  13. pose_down_angle=90.0,
  14. pose_type="pullup",
  15. ):
  16. """
  17. Initializes the AIGym class with the specified parameters.
  18. Args:
  19. kpts_to_check (list): Indices of keypoints to check.
  20. line_thickness (int, optional): Thickness of the lines drawn. Defaults to 2.
  21. view_img (bool, optional): Flag to display the image. Defaults to False.
  22. pose_up_angle (float, optional): Angle threshold for the 'up' pose. Defaults to 145.0.
  23. pose_down_angle (float, optional): Angle threshold for the 'down' pose. Defaults to 90.0.
  24. pose_type (str, optional): Type of pose to detect ('pullup', 'pushup', 'abworkout'). Defaults to "pullup".
  25. """
  26. # Image and line thickness
  27. self.im0 = None
  28. self.tf = line_thickness
  29. # Keypoints and count information
  30. self.keypoints = None
  31. self.poseup_angle = pose_up_angle
  32. self.posedown_angle = pose_down_angle
  33. self.threshold = 0.001
  34. # Store stage, count and angle information
  35. self.angle = None
  36. self.count = None
  37. self.stage = None
  38. self.pose_type = pose_type
  39. self.kpts_to_check = kpts_to_check
  40. # Visual Information
  41. self.view_img = view_img
  42. self.annotator = None
  43. # Check if environment supports imshow
  44. self.env_check = check_imshow(warn=True)
  45. self.count = []
  46. self.angle = []
  47. self.stage = []
  48. def start_counting(self, im0, results):
  49. """
  50. Function used to count the gym steps.
  51. Args:
  52. im0 (ndarray): Current frame from the video stream.
  53. results (list): Pose estimation data.
  54. """
  55. self.im0 = im0
  56. if not len(results[0]):
  57. return self.im0
  58. if len(results[0]) > len(self.count):
  59. new_human = len(results[0]) - len(self.count)
  60. self.count += [0] * new_human
  61. self.angle += [0] * new_human
  62. self.stage += ["-"] * new_human
  63. self.keypoints = results[0].keypoints.data
  64. self.annotator = Annotator(im0, line_width=self.tf)
  65. for ind, k in enumerate(reversed(self.keypoints)):
  66. # Estimate angle and draw specific points based on pose type
  67. if self.pose_type in {"pushup", "pullup", "abworkout", "squat"}:
  68. self.angle[ind] = self.annotator.estimate_pose_angle(
  69. k[int(self.kpts_to_check[0])].cpu(),
  70. k[int(self.kpts_to_check[1])].cpu(),
  71. k[int(self.kpts_to_check[2])].cpu(),
  72. )
  73. self.im0 = self.annotator.draw_specific_points(k, self.kpts_to_check, shape=(640, 640), radius=10)
  74. # Check and update pose stages and counts based on angle
  75. if self.pose_type in {"abworkout", "pullup"}:
  76. if self.angle[ind] > self.poseup_angle:
  77. self.stage[ind] = "down"
  78. if self.angle[ind] < self.posedown_angle and self.stage[ind] == "down":
  79. self.stage[ind] = "up"
  80. self.count[ind] += 1
  81. elif self.pose_type in {"pushup", "squat"}:
  82. if self.angle[ind] > self.poseup_angle:
  83. self.stage[ind] = "up"
  84. if self.angle[ind] < self.posedown_angle and self.stage[ind] == "up":
  85. self.stage[ind] = "down"
  86. self.count[ind] += 1
  87. self.annotator.plot_angle_and_count_and_stage(
  88. angle_text=self.angle[ind],
  89. count_text=self.count[ind],
  90. stage_text=self.stage[ind],
  91. center_kpt=k[int(self.kpts_to_check[1])],
  92. )
  93. # Draw keypoints
  94. self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)
  95. # Display the image if environment supports it and view_img is True
  96. if self.env_check and self.view_img:
  97. cv2.imshow("Ultralytics YOLOv8 AI GYM", self.im0)
  98. if cv2.waitKey(1) & 0xFF == ord("q"):
  99. return
  100. return self.im0
  101. if __name__ == "__main__":
  102. kpts_to_check = [0, 1, 2] # example keypoints
  103. aigym = AIGym(kpts_to_check)