Add multiple lines graph support in analytics 8.2.26 (#13214)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Muhammad Rizwan Munawar 2024-05-30 14:38:53 +05:00 committed by GitHub
parent fbd8fdb53e
commit f67f5d707b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 128 additions and 18 deletions

View file

@ -1,5 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import warnings
from itertools import cycle
import cv2
@ -27,6 +28,7 @@ class Analytics:
fontsize=13,
view_img=False,
save_img=True,
max_points=50,
):
"""
Initialize the Analytics class with various chart types.
@ -45,6 +47,7 @@ class Analytics:
fontsize (int): Font size for chart text.
view_img (bool): Whether to display the image.
save_img (bool): Whether to save the image.
max_points (int): Specifies when to remove the oldest points in a graph for multiple lines.
"""
self.bg_color = bg_color
@ -53,12 +56,14 @@ class Analytics:
self.save_img = save_img
self.title = title
self.writer = writer
self.max_points = max_points
# Set figure size based on image shape
figsize = (im0_shape[0] / 100, im0_shape[1] / 100)
if type == "line":
# Initialize line plot
self.lines = {}
fig = Figure(facecolor=self.bg_color, figsize=figsize)
self.canvas = FigureCanvas(fig)
self.ax = fig.add_subplot(111, facecolor=self.bg_color)
@ -112,9 +117,53 @@ class Analytics:
self.ax.autoscale_view()
self.canvas.draw()
im0 = np.array(self.canvas.renderer.buffer_rgba())
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
self.write_and_display_line(im0)
# Display and save the updated graph
def update_multiple_lines(self, counts_dict, labels_list, frame_number):
"""
Update the line graph with multiple classes.
Args:
counts_dict (int): Dictionary include each class counts.
labels_list (int): list include each classes names.
frame_number (int): The current frame number.
"""
warnings.warn("Display is not supported for multiple lines, output will be stored normally!")
for obj in labels_list:
if obj not in self.lines:
(line,) = self.ax.plot([], [], label=obj, marker="o", markersize=15)
self.lines[obj] = line
x_data = self.lines[obj].get_xdata()
y_data = self.lines[obj].get_ydata()
# Remove the initial point if the number of points exceeds max_points
if len(x_data) >= self.max_points:
x_data = np.delete(x_data, 0)
y_data = np.delete(y_data, 0)
x_data = np.append(x_data, float(frame_number)) # Ensure frame_number is converted to float
y_data = np.append(y_data, float(counts_dict.get(obj, 0))) # Ensure total_count is converted to float
self.lines[obj].set_data(x_data, y_data)
self.ax.relim()
self.ax.autoscale_view()
self.ax.legend()
self.canvas.draw()
im0 = np.array(self.canvas.renderer.buffer_rgba())
self.view_img = False # for multiple line view_img not supported yet, coming soon!
self.write_and_display_line(im0)
def write_and_display_line(self, im0):
"""
Write and display the line graph
Args:
im0 (ndarray): Image for processing
"""
# convert image to BGR format
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
cv2.imshow(self.title, im0) if self.view_img else None
self.writer.write(im0) if self.save_img else None