analytics.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. import warnings
  3. from itertools import cycle
  4. import cv2
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
  8. from matplotlib.figure import Figure
  9. class Analytics:
  10. """A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
  11. def __init__(
  12. self,
  13. type,
  14. writer,
  15. im0_shape,
  16. title="ultralytics",
  17. x_label="x",
  18. y_label="y",
  19. bg_color="white",
  20. fg_color="black",
  21. line_color="yellow",
  22. line_width=2,
  23. points_width=10,
  24. fontsize=13,
  25. view_img=False,
  26. save_img=True,
  27. max_points=50,
  28. ):
  29. """
  30. Initialize the Analytics class with various chart types.
  31. Args:
  32. type (str): Type of chart to initialize ('line', 'bar', 'pie', or 'area').
  33. writer (object): Video writer object to save the frames.
  34. im0_shape (tuple): Shape of the input image (width, height).
  35. title (str): Title of the chart.
  36. x_label (str): Label for the x-axis.
  37. y_label (str): Label for the y-axis.
  38. bg_color (str): Background color of the chart.
  39. fg_color (str): Foreground (text) color of the chart.
  40. line_color (str): Line color for line charts.
  41. line_width (int): Width of the lines in line charts.
  42. points_width (int): Width of line points highlighter
  43. fontsize (int): Font size for chart text.
  44. view_img (bool): Whether to display the image.
  45. save_img (bool): Whether to save the image.
  46. max_points (int): Specifies when to remove the oldest points in a graph for multiple lines.
  47. """
  48. self.bg_color = bg_color
  49. self.fg_color = fg_color
  50. self.view_img = view_img
  51. self.save_img = save_img
  52. self.title = title
  53. self.writer = writer
  54. self.max_points = max_points
  55. self.line_color = line_color
  56. self.x_label = x_label
  57. self.y_label = y_label
  58. self.points_width = points_width
  59. self.line_width = line_width
  60. self.fontsize = fontsize
  61. # Set figure size based on image shape
  62. figsize = (im0_shape[0] / 100, im0_shape[1] / 100)
  63. if type in {"line", "area"}:
  64. # Initialize line or area plot
  65. self.lines = {}
  66. self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
  67. self.canvas = FigureCanvas(self.fig)
  68. self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
  69. if type == "line":
  70. (self.line,) = self.ax.plot([], [], color=self.line_color, linewidth=self.line_width)
  71. elif type in {"bar", "pie"}:
  72. # Initialize bar or pie plot
  73. self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)
  74. self.ax.set_facecolor(self.bg_color)
  75. color_palette = [
  76. (31, 119, 180),
  77. (255, 127, 14),
  78. (44, 160, 44),
  79. (214, 39, 40),
  80. (148, 103, 189),
  81. (140, 86, 75),
  82. (227, 119, 194),
  83. (127, 127, 127),
  84. (188, 189, 34),
  85. (23, 190, 207),
  86. ]
  87. self.color_palette = [(r / 255, g / 255, b / 255, 1) for r, g, b in color_palette]
  88. self.color_cycle = cycle(self.color_palette)
  89. self.color_mapping = {}
  90. # Ensure pie chart is circular
  91. self.ax.axis("equal") if type == "pie" else None
  92. # Set common axis properties
  93. self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
  94. self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=self.fontsize - 3)
  95. self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=self.fontsize - 3)
  96. self.ax.tick_params(axis="both", colors=self.fg_color)
  97. def update_area(self, frame_number, counts_dict):
  98. """
  99. Update the area graph with new data for multiple classes.
  100. Args:
  101. frame_number (int): The current frame number.
  102. counts_dict (dict): Dictionary with class names as keys and counts as values.
  103. """
  104. x_data = np.array([])
  105. y_data_dict = {key: np.array([]) for key in counts_dict.keys()}
  106. if self.ax.lines:
  107. x_data = self.ax.lines[0].get_xdata()
  108. for line, key in zip(self.ax.lines, counts_dict.keys()):
  109. y_data_dict[key] = line.get_ydata()
  110. x_data = np.append(x_data, float(frame_number))
  111. max_length = len(x_data)
  112. for key in counts_dict.keys():
  113. y_data_dict[key] = np.append(y_data_dict[key], float(counts_dict[key]))
  114. if len(y_data_dict[key]) < max_length:
  115. y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant")
  116. # Remove the oldest points if the number of points exceeds max_points
  117. if len(x_data) > self.max_points:
  118. x_data = x_data[1:]
  119. for key in counts_dict.keys():
  120. y_data_dict[key] = y_data_dict[key][1:]
  121. self.ax.clear()
  122. colors = ["#E1FF25", "#0BDBEB", "#FF64DA", "#111F68", "#042AFF"]
  123. color_cycle = cycle(colors)
  124. for key, y_data in y_data_dict.items():
  125. color = next(color_cycle)
  126. self.ax.fill_between(x_data, y_data, color=color, alpha=0.6)
  127. self.ax.plot(
  128. x_data,
  129. y_data,
  130. color=color,
  131. linewidth=self.line_width,
  132. marker="o",
  133. markersize=self.points_width,
  134. label=f"{key} Data Points",
  135. )
  136. self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
  137. self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
  138. self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
  139. legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.fg_color)
  140. # Set legend text color
  141. for text in legend.get_texts():
  142. text.set_color(self.fg_color)
  143. self.canvas.draw()
  144. im0 = np.array(self.canvas.renderer.buffer_rgba())
  145. self.write_and_display(im0)
  146. def update_line(self, frame_number, total_counts):
  147. """
  148. Update the line graph with new data.
  149. Args:
  150. frame_number (int): The current frame number.
  151. total_counts (int): The total counts to plot.
  152. """
  153. # Update line graph data
  154. x_data = self.line.get_xdata()
  155. y_data = self.line.get_ydata()
  156. x_data = np.append(x_data, float(frame_number))
  157. y_data = np.append(y_data, float(total_counts))
  158. self.line.set_data(x_data, y_data)
  159. self.ax.relim()
  160. self.ax.autoscale_view()
  161. self.canvas.draw()
  162. im0 = np.array(self.canvas.renderer.buffer_rgba())
  163. self.write_and_display(im0)
  164. def update_multiple_lines(self, counts_dict, labels_list, frame_number):
  165. """
  166. Update the line graph with multiple classes.
  167. Args:
  168. counts_dict (int): Dictionary include each class counts.
  169. labels_list (int): list include each classes names.
  170. frame_number (int): The current frame number.
  171. """
  172. warnings.warn("Display is not supported for multiple lines, output will be stored normally!")
  173. for obj in labels_list:
  174. if obj not in self.lines:
  175. (line,) = self.ax.plot([], [], label=obj, marker="o", markersize=self.points_width)
  176. self.lines[obj] = line
  177. x_data = self.lines[obj].get_xdata()
  178. y_data = self.lines[obj].get_ydata()
  179. # Remove the initial point if the number of points exceeds max_points
  180. if len(x_data) >= self.max_points:
  181. x_data = np.delete(x_data, 0)
  182. y_data = np.delete(y_data, 0)
  183. x_data = np.append(x_data, float(frame_number)) # Ensure frame_number is converted to float
  184. y_data = np.append(y_data, float(counts_dict.get(obj, 0))) # Ensure total_count is converted to float
  185. self.lines[obj].set_data(x_data, y_data)
  186. self.ax.relim()
  187. self.ax.autoscale_view()
  188. self.ax.legend()
  189. self.canvas.draw()
  190. im0 = np.array(self.canvas.renderer.buffer_rgba())
  191. self.view_img = False # for multiple line view_img not supported yet, coming soon!
  192. self.write_and_display(im0)
  193. def write_and_display(self, im0):
  194. """
  195. Write and display the line graph
  196. Args:
  197. im0 (ndarray): Image for processing
  198. """
  199. im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
  200. cv2.imshow(self.title, im0) if self.view_img else None
  201. self.writer.write(im0) if self.save_img else None
  202. def update_bar(self, count_dict):
  203. """
  204. Update the bar graph with new data.
  205. Args:
  206. count_dict (dict): Dictionary containing the count data to plot.
  207. """
  208. # Update bar graph data
  209. self.ax.clear()
  210. self.ax.set_facecolor(self.bg_color)
  211. labels = list(count_dict.keys())
  212. counts = list(count_dict.values())
  213. # Map labels to colors
  214. for label in labels:
  215. if label not in self.color_mapping:
  216. self.color_mapping[label] = next(self.color_cycle)
  217. colors = [self.color_mapping[label] for label in labels]
  218. bars = self.ax.bar(labels, counts, color=colors)
  219. for bar, count in zip(bars, counts):
  220. self.ax.text(
  221. bar.get_x() + bar.get_width() / 2,
  222. bar.get_height(),
  223. str(count),
  224. ha="center",
  225. va="bottom",
  226. color=self.fg_color,
  227. )
  228. # Display and save the updated graph
  229. canvas = FigureCanvas(self.fig)
  230. canvas.draw()
  231. buf = canvas.buffer_rgba()
  232. im0 = np.asarray(buf)
  233. self.write_and_display(im0)
  234. def update_pie(self, classes_dict):
  235. """
  236. Update the pie chart with new data.
  237. Args:
  238. classes_dict (dict): Dictionary containing the class data to plot.
  239. """
  240. # Update pie chart data
  241. labels = list(classes_dict.keys())
  242. sizes = list(classes_dict.values())
  243. total = sum(sizes)
  244. percentages = [size / total * 100 for size in sizes]
  245. start_angle = 90
  246. self.ax.clear()
  247. # Create pie chart without labels inside the slices
  248. wedges, autotexts = self.ax.pie(sizes, autopct=None, startangle=start_angle, textprops={"color": self.fg_color})
  249. # Construct legend labels with percentages
  250. legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)]
  251. self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
  252. # Adjust layout to fit the legend
  253. self.fig.tight_layout()
  254. self.fig.subplots_adjust(left=0.1, right=0.75)
  255. # Display and save the updated chart
  256. im0 = self.fig.canvas.draw()
  257. im0 = np.array(self.fig.canvas.renderer.buffer_rgba())
  258. self.write_and_display(im0)
  259. if __name__ == "__main__":
  260. Analytics("line", writer=None, im0_shape=None)