Add area chart in analytics (#13391)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
5376b1a42e
commit
89108513c4
2 changed files with 157 additions and 40 deletions
|
|
@ -229,27 +229,80 @@ This guide provides a comprehensive overview of three fundamental types of data
|
|||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
=== "Area chart"
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from ultralytics import YOLO, solutions
|
||||
model = YOLO("yolov8s.pt")
|
||||
|
||||
cap = cv2.VideoCapture("path/to/video/file.mp4")
|
||||
assert cap.isOpened(), "Error reading video file"
|
||||
w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
|
||||
|
||||
out = cv2.VideoWriter("area_plot.avi", cv2.VideoWriter_fourcc(*"MJPG"), fps, (w, h))
|
||||
|
||||
analytics = solutions.Analytics(
|
||||
type="area",
|
||||
writer=out,
|
||||
im0_shape=(w, h),
|
||||
view_img=True,
|
||||
)
|
||||
|
||||
clswise_count = {}
|
||||
frame_count = 0
|
||||
|
||||
while cap.isOpened():
|
||||
success, frame = cap.read()
|
||||
if success:
|
||||
|
||||
frame_count += 1
|
||||
results = model.track(frame, persist=True, verbose=True)
|
||||
|
||||
if results[0].boxes.id is not None:
|
||||
boxes = results[0].boxes.xyxy.cpu()
|
||||
clss = results[0].boxes.cls.cpu().tolist()
|
||||
|
||||
for box, cls in zip(boxes, clss):
|
||||
if model.names[int(cls)] in clswise_count:
|
||||
clswise_count[model.names[int(cls)]] += 1
|
||||
else:
|
||||
clswise_count[model.names[int(cls)]] = 1
|
||||
|
||||
analytics.update_area(frame_count, clswise_count)
|
||||
clswise_count = {}
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
cap.release()
|
||||
out.release()
|
||||
cv2.destroyAllWindows()
|
||||
```
|
||||
|
||||
### Argument `Analytics`
|
||||
|
||||
Here's a table with the `Analytics` arguments:
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|--------------|-------------------|---------------|----------------------------------------------------------------------------------|
|
||||
| `type` | `str` | `None` | Type of data or object. |
|
||||
| `im0_shape` | `tuple` | `None` | Shape of the initial image. |
|
||||
| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. |
|
||||
| `title` | `str` | `ultralytics` | Title for the visualization. |
|
||||
| `x_label` | `str` | `x` | Label for the x-axis. |
|
||||
| `y_label` | `str` | `y` | Label for the y-axis. |
|
||||
| `bg_color` | `str` | `white` | Background color. |
|
||||
| `fg_color` | `str` | `black` | Foreground color. |
|
||||
| `line_color` | `str` | `yellow` | Color of the lines. |
|
||||
| `line_width` | `int` | `2` | Width of the lines. |
|
||||
| `fontsize` | `int` | `13` | Font size for text. |
|
||||
| `view_img` | `bool` | `False` | Flag to display the image or video. |
|
||||
| `save_img` | `bool` | `True` | Flag to save the image or video. |
|
||||
| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. |
|
||||
| Name | Type | Default | Description |
|
||||
|----------------|-------------------|---------------|----------------------------------------------------------------------------------|
|
||||
| `type` | `str` | `None` | Type of data or object. |
|
||||
| `im0_shape` | `tuple` | `None` | Shape of the initial image. |
|
||||
| `writer` | `cv2.VideoWriter` | `None` | Object for writing video files. |
|
||||
| `title` | `str` | `ultralytics` | Title for the visualization. |
|
||||
| `x_label` | `str` | `x` | Label for the x-axis. |
|
||||
| `y_label` | `str` | `y` | Label for the y-axis. |
|
||||
| `bg_color` | `str` | `white` | Background color. |
|
||||
| `fg_color` | `str` | `black` | Foreground color. |
|
||||
| `line_color` | `str` | `yellow` | Color of the lines. |
|
||||
| `line_width` | `int` | `2` | Width of the lines. |
|
||||
| `fontsize` | `int` | `13` | Font size for text. |
|
||||
| `view_img` | `bool` | `False` | Flag to display the image or video. |
|
||||
| `save_img` | `bool` | `True` | Flag to save the image or video. |
|
||||
| `max_points` | `int` | `50` | For multiple lines, total points drawn on frame, before deleting initial points. |
|
||||
| `points_width` | `int` | `15` | Width of line points highlighter. |
|
||||
|
||||
### Arguments `model.track`
|
||||
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from matplotlib.figure import Figure
|
|||
|
||||
|
||||
class Analytics:
|
||||
"""A class to create and update various types of charts (line, bar, pie) for visual analytics."""
|
||||
"""A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -25,6 +25,7 @@ class Analytics:
|
|||
fg_color="black",
|
||||
line_color="yellow",
|
||||
line_width=2,
|
||||
points_width=10,
|
||||
fontsize=13,
|
||||
view_img=False,
|
||||
save_img=True,
|
||||
|
|
@ -34,7 +35,7 @@ class Analytics:
|
|||
Initialize the Analytics class with various chart types.
|
||||
|
||||
Args:
|
||||
type (str): Type of chart to initialize ('line', 'bar', or 'pie').
|
||||
type (str): Type of chart to initialize ('line', 'bar', 'pie', or 'area').
|
||||
writer (object): Video writer object to save the frames.
|
||||
im0_shape (tuple): Shape of the input image (width, height).
|
||||
title (str): Title of the chart.
|
||||
|
|
@ -44,6 +45,7 @@ class Analytics:
|
|||
fg_color (str): Foreground (text) color of the chart.
|
||||
line_color (str): Line color for line charts.
|
||||
line_width (int): Width of the lines in line charts.
|
||||
points_width (int): Width of line points highlighter
|
||||
fontsize (int): Font size for chart text.
|
||||
view_img (bool): Whether to display the image.
|
||||
save_img (bool): Whether to save the image.
|
||||
|
|
@ -57,17 +59,24 @@ class Analytics:
|
|||
self.title = title
|
||||
self.writer = writer
|
||||
self.max_points = max_points
|
||||
self.line_color = line_color
|
||||
self.x_label = x_label
|
||||
self.y_label = y_label
|
||||
self.points_width = points_width
|
||||
self.line_width = line_width
|
||||
self.fontsize = fontsize
|
||||
|
||||
# Set figure size based on image shape
|
||||
figsize = (im0_shape[0] / 100, im0_shape[1] / 100)
|
||||
|
||||
if type == "line":
|
||||
# Initialize line plot
|
||||
if type in {"line", "area"}:
|
||||
# Initialize line or area plot
|
||||
self.lines = {}
|
||||
fig = Figure(facecolor=self.bg_color, figsize=figsize)
|
||||
self.canvas = FigureCanvas(fig)
|
||||
self.ax = fig.add_subplot(111, facecolor=self.bg_color)
|
||||
(self.line,) = self.ax.plot([], [], color=line_color, linewidth=line_width)
|
||||
self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
|
||||
self.canvas = FigureCanvas(self.fig)
|
||||
self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
|
||||
if type == "line":
|
||||
(self.line,) = self.ax.plot([], [], color=self.line_color, linewidth=self.line_width)
|
||||
|
||||
elif type in {"bar", "pie"}:
|
||||
# Initialize bar or pie plot
|
||||
|
|
@ -93,11 +102,73 @@ class Analytics:
|
|||
self.ax.axis("equal") if type == "pie" else None
|
||||
|
||||
# Set common axis properties
|
||||
self.ax.set_title(self.title, color=self.fg_color, fontsize=fontsize)
|
||||
self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=fontsize - 3)
|
||||
self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=fontsize - 3)
|
||||
self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
|
||||
self.ax.set_xlabel(x_label, color=self.fg_color, fontsize=self.fontsize - 3)
|
||||
self.ax.set_ylabel(y_label, color=self.fg_color, fontsize=self.fontsize - 3)
|
||||
self.ax.tick_params(axis="both", colors=self.fg_color)
|
||||
|
||||
def update_area(self, frame_number, counts_dict):
|
||||
"""
|
||||
Update the area graph with new data for multiple classes.
|
||||
|
||||
Args:
|
||||
frame_number (int): The current frame number.
|
||||
counts_dict (dict): Dictionary with class names as keys and counts as values.
|
||||
"""
|
||||
|
||||
x_data = np.array([])
|
||||
y_data_dict = {key: np.array([]) for key in counts_dict.keys()}
|
||||
|
||||
if self.ax.lines:
|
||||
x_data = self.ax.lines[0].get_xdata()
|
||||
for line, key in zip(self.ax.lines, counts_dict.keys()):
|
||||
y_data_dict[key] = line.get_ydata()
|
||||
|
||||
x_data = np.append(x_data, float(frame_number))
|
||||
max_length = len(x_data)
|
||||
|
||||
for key in counts_dict.keys():
|
||||
y_data_dict[key] = np.append(y_data_dict[key], float(counts_dict[key]))
|
||||
if len(y_data_dict[key]) < max_length:
|
||||
y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])), "constant")
|
||||
|
||||
# Remove the oldest points if the number of points exceeds max_points
|
||||
if len(x_data) > self.max_points:
|
||||
x_data = x_data[1:]
|
||||
for key in counts_dict.keys():
|
||||
y_data_dict[key] = y_data_dict[key][1:]
|
||||
|
||||
self.ax.clear()
|
||||
|
||||
colors = ["#E1FF25", "#0BDBEB", "#FF64DA", "#111F68", "#042AFF"]
|
||||
color_cycle = cycle(colors)
|
||||
|
||||
for key, y_data in y_data_dict.items():
|
||||
color = next(color_cycle)
|
||||
self.ax.fill_between(x_data, y_data, color=color, alpha=0.6)
|
||||
self.ax.plot(
|
||||
x_data,
|
||||
y_data,
|
||||
color=color,
|
||||
linewidth=self.line_width,
|
||||
marker="o",
|
||||
markersize=self.points_width,
|
||||
label=f"{key} Data Points",
|
||||
)
|
||||
|
||||
self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
|
||||
self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
|
||||
self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
|
||||
legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.fg_color)
|
||||
|
||||
# Set legend text color
|
||||
for text in legend.get_texts():
|
||||
text.set_color(self.fg_color)
|
||||
|
||||
self.canvas.draw()
|
||||
im0 = np.array(self.canvas.renderer.buffer_rgba())
|
||||
self.write_and_display(im0)
|
||||
|
||||
def update_line(self, frame_number, total_counts):
|
||||
"""
|
||||
Update the line graph with new data.
|
||||
|
|
@ -117,7 +188,7 @@ class Analytics:
|
|||
self.ax.autoscale_view()
|
||||
self.canvas.draw()
|
||||
im0 = np.array(self.canvas.renderer.buffer_rgba())
|
||||
self.write_and_display_line(im0)
|
||||
self.write_and_display(im0)
|
||||
|
||||
def update_multiple_lines(self, counts_dict, labels_list, frame_number):
|
||||
"""
|
||||
|
|
@ -131,7 +202,7 @@ class Analytics:
|
|||
warnings.warn("Display is not supported for multiple lines, output will be stored normally!")
|
||||
for obj in labels_list:
|
||||
if obj not in self.lines:
|
||||
(line,) = self.ax.plot([], [], label=obj, marker="o", markersize=15)
|
||||
(line,) = self.ax.plot([], [], label=obj, marker="o", markersize=self.points_width)
|
||||
self.lines[obj] = line
|
||||
|
||||
x_data = self.lines[obj].get_xdata()
|
||||
|
|
@ -153,16 +224,14 @@ class Analytics:
|
|||
|
||||
im0 = np.array(self.canvas.renderer.buffer_rgba())
|
||||
self.view_img = False # for multiple line view_img not supported yet, coming soon!
|
||||
self.write_and_display_line(im0)
|
||||
self.write_and_display(im0)
|
||||
|
||||
def write_and_display_line(self, im0):
|
||||
def write_and_display(self, im0):
|
||||
"""
|
||||
Write and display the line graph
|
||||
Args:
|
||||
im0 (ndarray): Image for processing
|
||||
"""
|
||||
|
||||
# convert image to BGR format
|
||||
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
|
||||
cv2.imshow(self.title, im0) if self.view_img else None
|
||||
self.writer.write(im0) if self.save_img else None
|
||||
|
|
@ -204,10 +273,7 @@ class Analytics:
|
|||
canvas.draw()
|
||||
buf = canvas.buffer_rgba()
|
||||
im0 = np.asarray(buf)
|
||||
im0 = cv2.cvtColor(im0, cv2.COLOR_RGBA2BGR)
|
||||
|
||||
self.writer.write(im0) if self.save_img else None
|
||||
cv2.imshow(self.title, im0) if self.view_img else None
|
||||
self.write_and_display(im0)
|
||||
|
||||
def update_pie(self, classes_dict):
|
||||
"""
|
||||
|
|
@ -239,9 +305,7 @@ class Analytics:
|
|||
# Display and save the updated chart
|
||||
im0 = self.fig.canvas.draw()
|
||||
im0 = np.array(self.fig.canvas.renderer.buffer_rgba())
|
||||
im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR)
|
||||
self.writer.write(im0) if self.save_img else None
|
||||
cv2.imshow(self.title, im0) if self.view_img else None
|
||||
self.write_and_display(im0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue