ultralytics 8.3.16 PyTorch 2.5.0 support (#16998)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
parent ef28f1078c
commit 8d7d1fe390
17 changed files with 570 additions and 144 deletions
.github/workflows/publish.yml (vendored, 2 changes)
@@ -18,7 +18,7 @@ jobs:
     name: Publish
     runs-on: ubuntu-latest
     permissions:
-      id-token: write # for PyPI trusted publishing
+      id-token: write # for PyPI trusted publishing
     steps:
       - name: Checkout code
         uses: actions/checkout@v4
docs/mkdocs_github_authors.yaml
@@ -76,6 +76,9 @@
 79740115+0xSynapse@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/79740115?v=4
   username: 0xSynapse
+91465467+lalayants@users.noreply.github.com:
+  avatar: https://avatars.githubusercontent.com/u/91465467?v=4
+  username: lalayants
 Francesco.mttl@gmail.com:
   avatar: https://avatars.githubusercontent.com/u/3855193?v=4
   username: ambitious-octopus

mkdocs.yml
@@ -555,6 +555,7 @@ nav:
           - utils: reference/nn/modules/utils.md
         - tasks: reference/nn/tasks.md
       - solutions:
           - solutions: reference/solutions/solutions.md
           - ai_gym: reference/solutions/ai_gym.md
+          - analytics: reference/solutions/analytics.md
           - distance_calculation: reference/solutions/distance_calculation.md
pyproject.toml
@@ -26,7 +26,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "ultralytics"
 dynamic = ["version"]
-description = "Ultralytics YOLO for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification."
+description = "Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification."
 readme = "README.md"
 requires-python = ">=3.8"
 license = { "text" = "AGPL-3.0" }
tests/test_solutions.py
@@ -17,10 +17,15 @@ def test_major_solutions():
     cap = cv2.VideoCapture("solutions_ci_demo.mp4")
     assert cap.isOpened(), "Error reading video file"
     region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
-    counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)
-    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)
-    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)
-    queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)
+    counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)  # Test object counter
+    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)  # Test heatmaps
+    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)  # Test speed estimation
+    queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)  # Test queue manager
+    line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False)  # Line analytics
+    pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False)  # Pie analytics
+    bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False)  # Bar analytics
+    area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False)  # Area analytics
+    frame_count = 0  # Required for analytics
     while cap.isOpened():
         success, im0 = cap.read()
         if not success:
@@ -30,24 +35,23 @@ def test_major_solutions():
         _ = heatmap.generate_heatmap(original_im0.copy())
         _ = speed.estimate_speed(original_im0.copy())
         _ = queue.process_queue(original_im0.copy())
+        _ = line_analytics.process_data(original_im0.copy(), frame_count)
+        _ = pie_analytics.process_data(original_im0.copy(), frame_count)
+        _ = bar_analytics.process_data(original_im0.copy(), frame_count)
+        _ = area_analytics.process_data(original_im0.copy(), frame_count)
     cap.release()
     cv2.destroyAllWindows()


 @pytest.mark.slow
 def test_aigym():
     """Test the workouts monitoring solution."""
+    # Test workouts monitoring
     safe_download(url=WORKOUTS_SOLUTION_DEMO)
-    cap = cv2.VideoCapture("solution_ci_pose_demo.mp4")
-    assert cap.isOpened(), "Error reading video file"
-    gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13])
-    while cap.isOpened():
-        success, im0 = cap.read()
+    cap1 = cv2.VideoCapture("solution_ci_pose_demo.mp4")
+    assert cap1.isOpened(), "Error reading video file"
+    gym = solutions.AIGym(line_width=2, kpts=[5, 11, 13], show=False)
+    while cap1.isOpened():
+        success, im0 = cap1.read()
         if not success:
             break
         _ = gym.monitor(im0)
-    cap.release()
-    cv2.destroyAllWindows()
+    cap1.release()


 @pytest.mark.slow
ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.15"
+__version__ = "8.3.16"

 import os
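Note: after upgrading, the bump can be confirmed from Python; a trivial check, shown for completeness:

import ultralytics

print(ultralytics.__version__)  # expected: 8.3.16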
ultralytics/data/split_dota.py
@@ -13,9 +13,6 @@ from tqdm import tqdm
 from ultralytics.data.utils import exif_size, img2label_paths
 from ultralytics.utils.checks import check_requirements

-check_requirements("shapely")
-from shapely.geometry import Polygon
-

 def bbox_iof(polygon1, bbox2, eps=1e-6):
     """
@@ -33,6 +30,9 @@ def bbox_iof(polygon1, bbox2, eps=1e-6):
         Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
         Bounding box format: [x_min, y_min, x_max, y_max].
     """
+    check_requirements("shapely")
+    from shapely.geometry import Polygon
+
     polygon1 = polygon1.reshape(-1, 4, 2)
     lt_point = np.min(polygon1, axis=-2)  # left-top
     rb_point = np.max(polygon1, axis=-2)  # right-bottom
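Note: the change above defers the optional shapely dependency from import time to call time, so `import ultralytics` no longer pulls in shapely. A minimal standalone sketch of the same lazy-import pattern (the helper name is illustrative, not from this commit):

def area_of_quad(points):
    """Compute a quadrilateral's area; shapely is only required when this runs."""
    from shapely.geometry import Polygon  # imported at call time, not at module import

    return Polygon(points).area


print(area_of_quad([(0, 0), (4, 0), (4, 3), (0, 3)]))  # 12.0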
ultralytics/solutions/ai_gym.py
@@ -1,16 +1,40 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator


 class AIGym(BaseSolution):
-    """A class to manage the gym steps of people in a real-time video stream based on their poses."""
+    """
+    A class to manage gym steps of people in a real-time video stream based on their poses.
+
+    This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
+    repetitions of exercises based on predefined angle thresholds for up and down positions.
+
+    Attributes:
+        count (List[int]): Repetition counts for each detected person.
+        angle (List[float]): Current angle of the tracked body part for each person.
+        stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person.
+        initial_stage (str | None): Initial stage of the exercise.
+        up_angle (float): Angle threshold for considering the 'up' position of an exercise.
+        down_angle (float): Angle threshold for considering the 'down' position of an exercise.
+        kpts (List[int]): Indices of keypoints used for angle calculation.
+        lw (int): Line width for drawing annotations.
+        annotator (Annotator): Object for drawing annotations on the image.
+
+    Methods:
+        monitor: Processes a frame to detect poses, calculate angles, and count repetitions.
+
+    Examples:
+        >>> gym = AIGym(model="yolov8n-pose.pt")
+        >>> image = cv2.imread("gym_scene.jpg")
+        >>> processed_image = gym.monitor(image)
+        >>> cv2.imshow("Processed Image", processed_image)
+        >>> cv2.waitKey(0)
+    """

     def __init__(self, **kwargs):
-        """Initialization function for AiGYM class, a child class of BaseSolution class, can be used for workouts
-        monitoring.
-        """
+        """Initializes AIGym for workout monitoring using pose estimation and predefined angles."""
         # Check if the model name ends with '-pose'
         if "model" in kwargs and "-pose" not in kwargs["model"]:
             kwargs["model"] = "yolo11n-pose.pt"
@@ -31,12 +55,22 @@ class AIGym(BaseSolution):

     def monitor(self, im0):
         """
-        Monitor the workouts using Ultralytics YOLO Pose Model: https://docs.ultralytics.com/tasks/pose/.
+        Monitors workouts using Ultralytics YOLO Pose Model.
+
+        This function processes an input image to track and analyze human poses for workout monitoring. It uses
+        the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
+        angle thresholds.

         Args:
-            im0 (ndarray): The input image that will be used for processing
-        Returns
-            im0 (ndarray): The processed image for more usage
+            im0 (ndarray): Input image for processing.
+
+        Returns:
+            (ndarray): Processed image with annotations for workout monitoring.
+
+        Examples:
+            >>> gym = AIGym()
+            >>> image = cv2.imread("workout.jpg")
+            >>> processed_image = gym.monitor(image)
         """
         # Extract tracks
         tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])[0]
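Note: a minimal sketch of the up/down threshold logic the new docstring describes; the thresholds and the synthetic angle stream below are assumed values for illustration, not the class defaults:

def update_stage(angle, stage, count, up_angle=145.0, down_angle=90.0):
    """Return updated (stage, count) for one person given the current joint angle."""
    if angle < down_angle and stage != "down":
        stage = "down"  # entered the bottom of the rep
    elif angle > up_angle and stage == "down":
        stage, count = "up", count + 1  # completed one repetition
    return stage, count


stage, count = "-", 0
for angle in [160, 120, 85, 100, 150, 80, 155]:  # synthetic per-frame angles
    stage, count = update_stage(angle, stage, count)
print(count)  # 2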
ultralytics/solutions/analytics.py
@@ -12,10 +12,41 @@ from ultralytics.solutions.solutions import BaseSolution  # Import a parent class


 class Analytics(BaseSolution):
-    """A class to create and update various types of charts (line, bar, pie, area) for visual analytics."""
+    """
+    A class for creating and updating various types of charts for visual analytics.
+
+    This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
+    based on object detection and tracking data.
+
+    Attributes:
+        type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
+        x_label (str): Label for the x-axis.
+        y_label (str): Label for the y-axis.
+        bg_color (str): Background color of the chart frame.
+        fg_color (str): Foreground color of the chart frame.
+        title (str): Title of the chart window.
+        max_points (int): Maximum number of data points to display on the chart.
+        fontsize (int): Font size for text display.
+        color_cycle (cycle): Cyclic iterator for chart colors.
+        total_counts (int): Total count of detected objects (used for line charts).
+        clswise_count (Dict[str, int]): Dictionary for class-wise object counts.
+        fig (Figure): Matplotlib figure object for the chart.
+        ax (Axes): Matplotlib axes object for the chart.
+        canvas (FigureCanvas): Canvas for rendering the chart.
+
+    Methods:
+        process_data: Processes image data and updates the chart.
+        update_graph: Updates the chart with new data points.
+
+    Examples:
+        >>> analytics = Analytics(analytics_type="line")
+        >>> frame = cv2.imread("image.jpg")
+        >>> processed_frame = analytics.process_data(frame, frame_number=1)
+        >>> cv2.imshow("Analytics", processed_frame)
+    """

     def __init__(self, **kwargs):
-        """Initialize the Analytics class with various chart types."""
+        """Initialize Analytics class with various chart types for visual data representation."""
         super().__init__(**kwargs)

         self.type = self.CFG["analytics_type"]  # extract type of analytics
@@ -31,8 +62,8 @@ class Analytics(BaseSolution):
         figsize = (19.2, 10.8)  # Set output image size 1920 * 1080
         self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])

-        self.total_counts = 0  # count variable for storing total counts i.e for line
-        self.clswise_count = {}  # dictionary for classwise counts
+        self.total_counts = 0  # count variable for storing total counts i.e. for line
+        self.clswise_count = {}  # dictionary for class-wise counts

         # Ensure line and area chart
         if self.type in {"line", "area"}:
@@ -48,15 +79,28 @@ class Analytics(BaseSolution):
             self.canvas = FigureCanvas(self.fig)  # Set common axis properties
             self.ax.set_facecolor(self.bg_color)
             self.color_mapping = {}
-            self.ax.axis("equal") if self.type == "pie" else None  # Ensure pie chart is circular
+
+            if self.type == "pie":  # Ensure pie chart is circular
+                self.ax.axis("equal")

     def process_data(self, im0, frame_number):
         """
-        Process the image data, run object tracking.
+        Processes image data and runs object tracking to update analytics charts.

         Args:
-            im0 (ndarray): Input image for processing.
-            frame_number (int): Video frame # for plotting the data.
+            im0 (np.ndarray): Input image for processing.
+            frame_number (int): Video frame number for plotting the data.
+
+        Returns:
+            (np.ndarray): Processed image with updated analytics chart.
+
+        Raises:
+            ModuleNotFoundError: If an unsupported chart type is specified.
+
+        Examples:
+            >>> analytics = Analytics(analytics_type="line")
+            >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
+            >>> processed_frame = analytics.process_data(frame, frame_number=1)
         """
         self.extract_tracks(im0)  # Extract tracks

@@ -79,13 +123,22 @@ class Analytics(BaseSolution):

     def update_graph(self, frame_number, count_dict=None, plot="line"):
         """
-        Update the graph (line or area) with new data for single or multiple classes.
+        Updates the graph with new data for single or multiple classes.

         Args:
             frame_number (int): The current frame number.
-            count_dict (dict, optional): Dictionary with class names as keys and counts as values for multiple classes.
-                If None, updates a single line graph.
-            plot (str): Type of the plot i.e. line, bar or area.
+            count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple
+                classes. If None, updates a single line graph.
+            plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
+
+        Returns:
+            (np.ndarray): Updated image containing the graph.
+
+        Examples:
+            >>> analytics = Analytics()
+            >>> frame_number = 10
+            >>> count_dict = {"person": 5, "car": 3}
+            >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar")
         """
         if count_dict is None:
             # Single line update
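Note: a minimal sketch of how a Matplotlib chart becomes an image frame, the same Agg-canvas route the Analytics class relies on; the figure size matches the 1920x1080 output above, while the colors here are assumptions for illustration:

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure

fig = Figure(figsize=(19.2, 10.8))  # 1920 x 1080 at the default 100 dpi
canvas = FigureCanvas(fig)
ax = fig.add_subplot(111)
ax.plot([0, 1, 2, 3], [0, 2, 1, 3], color="#DD00BA", linewidth=2)  # counts per frame
canvas.draw()
frame = np.asarray(canvas.buffer_rgba())  # RGBA array, convertible for cv2 display
print(frame.shape)  # (1080, 1920, 4)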
ultralytics/solutions/distance_calculation.py
@@ -4,15 +4,41 @@ import math

 import cv2

-from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator, colors


 class DistanceCalculation(BaseSolution):
-    """A class to calculate distance between two objects in a real-time video stream based on their tracks."""
+    """
+    A class to calculate distance between two objects in a real-time video stream based on their tracks.
+
+    This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
+    between them in a video stream using YOLO object detection and tracking.
+
+    Attributes:
+        left_mouse_count (int): Counter for left mouse button clicks.
+        selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs.
+        annotator (Annotator): An instance of the Annotator class for drawing on the image.
+        boxes (List[List[float]]): List of bounding boxes for detected objects.
+        track_ids (List[int]): List of track IDs for detected objects.
+        clss (List[int]): List of class indices for detected objects.
+        names (List[str]): List of class names that the model can detect.
+        centroids (List[List[int]]): List to store centroids of selected bounding boxes.
+
+    Methods:
+        mouse_event_for_distance: Handles mouse events for selecting objects in the video stream.
+        calculate: Processes video frames and calculates the distance between selected objects.
+
+    Examples:
+        >>> distance_calc = DistanceCalculation()
+        >>> frame = cv2.imread("frame.jpg")
+        >>> processed_frame = distance_calc.calculate(frame)
+        >>> cv2.imshow("Distance Calculation", processed_frame)
+        >>> cv2.waitKey(0)
+    """

     def __init__(self, **kwargs):
-        """Initializes the DistanceCalculation class with the given parameters."""
+        """Initializes the DistanceCalculation class for measuring object distances in video streams."""
         super().__init__(**kwargs)

         # Mouse event information
@@ -21,14 +47,18 @@ class DistanceCalculation(BaseSolution):

     def mouse_event_for_distance(self, event, x, y, flags, param):
         """
-        Handles mouse events to select regions in a real-time video stream.
+        Handles mouse events to select regions in a real-time video stream for distance calculation.

         Args:
-            event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN, etc.).
+            event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
             x (int): X-coordinate of the mouse pointer.
             y (int): Y-coordinate of the mouse pointer.
-            flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY, etc.).
-            param (dict): Additional parameters passed to the function.
+            flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
+            param (Dict): Additional parameters passed to the function.
+
+        Examples:
+            >>> # Assuming 'dc' is an instance of DistanceCalculation
+            >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
         """
         if event == cv2.EVENT_LBUTTONDOWN:
             self.left_mouse_count += 1
@@ -43,13 +73,23 @@ class DistanceCalculation(BaseSolution):

     def calculate(self, im0):
         """
-        Processes the video frame and calculates the distance between two bounding boxes.
+        Processes a video frame and calculates the distance between two selected bounding boxes.
+
+        This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
+        between two user-selected objects if they have been chosen.

         Args:
-            im0 (ndarray): The image frame.
+            im0 (numpy.ndarray): The input image frame to process.

         Returns:
-            (ndarray): The processed image frame.
+            (numpy.ndarray): The processed image frame with annotations and distance calculations.
+
+        Examples:
+            >>> import numpy as np
+            >>> from ultralytics.solutions import DistanceCalculation
+            >>> dc = DistanceCalculation()
+            >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+            >>> processed_frame = dc.calculate(frame)
         """
         self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
         self.extract_tracks(im0)  # Extract tracks
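Note: a minimal sketch of the measurement itself, i.e. centroids from two [x1, y1, x2, y2] boxes and their Euclidean distance in pixels; the box values below are synthetic:

import math

box_a = [100, 200, 300, 400]
box_b = [500, 250, 700, 450]
centroid_a = ((box_a[0] + box_a[2]) / 2, (box_a[1] + box_a[3]) / 2)  # (200.0, 300.0)
centroid_b = ((box_b[0] + box_b[2]) / 2, (box_b[1] + box_b[3]) / 2)  # (600.0, 350.0)
pixel_distance = math.hypot(centroid_a[0] - centroid_b[0], centroid_a[1] - centroid_b[1])
print(f"{pixel_distance:.1f} px")  # 403.1 px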
ultralytics/solutions/heatmap.py
@@ -3,15 +3,40 @@
 import cv2
 import numpy as np

-from ultralytics.solutions.object_counter import ObjectCounter  # Import object counter class
+from ultralytics.solutions.object_counter import ObjectCounter
 from ultralytics.utils.plotting import Annotator


 class Heatmap(ObjectCounter):
-    """A class to draw heatmaps in real-time video stream based on their tracks."""
+    """
+    A class to draw heatmaps in real-time video streams based on object tracks.
+
+    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
+    streams. It uses tracked object positions to create a cumulative heatmap effect over time.
+
+    Attributes:
+        initialized (bool): Flag indicating whether the heatmap has been initialized.
+        colormap (int): OpenCV colormap used for heatmap visualization.
+        heatmap (np.ndarray): Array storing the cumulative heatmap data.
+        annotator (Annotator): Object for drawing annotations on the image.
+
+    Methods:
+        heatmap_effect: Calculates and updates the heatmap effect for a given bounding box.
+        generate_heatmap: Generates and applies the heatmap effect to each frame.
+
+    Examples:
+        >>> from ultralytics.solutions import Heatmap
+        >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET)
+        >>> results = heatmap("path/to/video.mp4")
+        >>> for result in results:
+        ...     print(result.speed)  # Print inference speed
+        ...     cv2.imshow("Heatmap", result.plot())
+        ...     if cv2.waitKey(1) & 0xFF == ord("q"):
+        ...         break
+    """

     def __init__(self, **kwargs):
-        """Initializes function for heatmap class with default values."""
+        """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks."""
         super().__init__(**kwargs)

         self.initialized = False  # bool variable for heatmap initialization
@@ -23,10 +48,15 @@ class Heatmap(ObjectCounter):

     def heatmap_effect(self, box):
         """
-        Efficient calculation of heatmap area and effect location for applying colormap.
+        Efficiently calculates heatmap area and effect location for applying colormap.

         Args:
-            box (list): Bounding Box coordinates data [x0, y0, x1, y1]
+            box (List[float]): Bounding box coordinates [x0, y0, x1, y1].
+
+        Examples:
+            >>> heatmap = Heatmap()
+            >>> box = [100, 100, 200, 200]
+            >>> heatmap.heatmap_effect(box)
         """
         x0, y0, x1, y1 = map(int, box)
         radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
@@ -48,9 +78,15 @@ class Heatmap(ObjectCounter):
         Generate heatmap for each frame using Ultralytics.

         Args:
-            im0 (ndarray): Input image array for processing
+            im0 (np.ndarray): Input image array for processing.

         Returns:
-            im0 (ndarray): Processed image for further usage
+            (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified).
+
+        Examples:
+            >>> heatmap = Heatmap()
+            >>> im0 = cv2.imread("image.jpg")
+            >>> result = heatmap.generate_heatmap(im0)
         """
         if not self.initialized:
             self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99
@@ -70,16 +106,17 @@ class Heatmap(ObjectCounter):
             self.store_classwise_counts(cls)  # store classwise counts in dict

             # Store tracking previous position and perform object counting
-            prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+            prev_position = None
+            if len(self.track_history[track_id]) > 1:
+                prev_position = self.track_history[track_id][-2]
             self.count_objects(self.track_line, box, track_id, prev_position, cls)  # Perform object counting

-        self.display_counts(im0) if self.region is not None else None  # Display the counts on the frame
+        if self.region is not None:
+            self.display_counts(im0)  # Display the counts on the frame

         # Normalize, apply colormap to heatmap and combine with original image
-        im0 = (
-            im0
-            if self.track_data.id is None
-            else cv2.addWeighted(
+        if self.track_data.id is not None:
+            im0 = cv2.addWeighted(
                 im0,
                 0.5,
                 cv2.applyColorMap(
@@ -88,7 +125,6 @@ class Heatmap(ObjectCounter):
                 0.5,
                 0,
             )
-        )

         self.display_output(im0)  # display output with base class function
         return im0  # return output image for more usage
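Note: a minimal sketch of the overlay step in generate_heatmap above; normalize an accumulated float heatmap, apply an OpenCV colormap, and alpha-blend it onto the frame. The frame and accumulation values here are synthetic:

import cv2
import numpy as np

frame = np.full((240, 320, 3), 128, dtype=np.uint8)  # stand-in video frame
heat = np.zeros((240, 320), dtype=np.float32)
heat[100:140, 140:180] += 5.0  # accumulated detections in one area

norm = cv2.normalize(heat, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
colored = cv2.applyColorMap(norm, cv2.COLORMAP_PARULA)  # colormap used in the test above
overlay = cv2.addWeighted(frame, 0.5, colored, 0.5, 0)  # 50/50 blend as in the diff
print(overlay.shape, overlay.dtype)  # (240, 320, 3) uint8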
ultralytics/solutions/object_counter.py
@@ -1,18 +1,40 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from shapely.geometry import LineString, Point
-
-from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator, colors


 class ObjectCounter(BaseSolution):
-    """A class to manage the counting of objects in a real-time video stream based on their tracks."""
+    """
+    A class to manage the counting of objects in a real-time video stream based on their tracks.
+
+    This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a
+    specified region in a video stream. It supports both polygonal and linear regions for counting.
+
+    Attributes:
+        in_count (int): Counter for objects moving inward.
+        out_count (int): Counter for objects moving outward.
+        counted_ids (List[int]): List of IDs of objects that have been counted.
+        classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class.
+        region_initialized (bool): Flag indicating whether the counting region has been initialized.
+        show_in (bool): Flag to control display of inward count.
+        show_out (bool): Flag to control display of outward count.
+
+    Methods:
+        count_objects: Counts objects within a polygonal or linear region.
+        store_classwise_counts: Initializes class-wise counts if not already present.
+        display_counts: Displays object counts on the frame.
+        count: Processes input data (frames or object tracks) and updates counts.
+
+    Examples:
+        >>> counter = ObjectCounter()
+        >>> frame = cv2.imread("frame.jpg")
+        >>> processed_frame = counter.count(frame)
+        >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}")
+    """

     def __init__(self, **kwargs):
-        """Initialization function for Count class, a child class of BaseSolution class, can be used for counting the
-        objects.
-        """
+        """Initializes the ObjectCounter class for real-time object counting in video streams."""
         super().__init__(**kwargs)

         self.in_count = 0  # Counter for objects moving inward
@@ -26,14 +48,23 @@ class ObjectCounter(BaseSolution):

     def count_objects(self, track_line, box, track_id, prev_position, cls):
         """
-        Helper function to count objects within a polygonal region.
+        Counts objects within a polygonal or linear region based on their tracks.

         Args:
-            track_line (dict): last 30 frame track record
-            box (list): Bounding box data for specific track in current frame
-            track_id (int): track ID of the object
-            prev_position (tuple): last frame position coordinates of the track
-            cls (int): Class index for classwise count updates
+            track_line (Dict): Last 30 frame track record for the object.
+            box (List[float]): Bounding box coordinates [x1, y1, x2, y2] for the specific track in the current frame.
+            track_id (int): Unique identifier for the tracked object.
+            prev_position (Tuple[float, float]): Last frame position coordinates (x, y) of the track.
+            cls (int): Class index for classwise count updates.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]}
+            >>> box = [130, 230, 150, 250]
+            >>> track_id = 1
+            >>> prev_position = (120, 220)
+            >>> cls = 0
+            >>> counter.count_objects(track_line, box, track_id, prev_position, cls)
         """
         if prev_position is None or track_id in self.counted_ids:
             return
@@ -42,7 +73,7 @@ class ObjectCounter(BaseSolution):
         dx = (box[0] - prev_position[0]) * (centroid.x - prev_position[0])
         dy = (box[1] - prev_position[1]) * (centroid.y - prev_position[1])

-        if len(self.region) >= 3 and self.r_s.contains(Point(track_line[-1])):
+        if len(self.region) >= 3 and self.r_s.contains(self.Point(track_line[-1])):
             self.counted_ids.append(track_id)
             # For polygon region
             if dx > 0:
@@ -52,7 +83,7 @@ class ObjectCounter(BaseSolution):
                 self.out_count += 1
                 self.classwise_counts[self.names[cls]]["OUT"] += 1

-        elif len(self.region) < 3 and LineString([prev_position, box[:2]]).intersects(self.l_s):
+        elif len(self.region) < 3 and self.LineString([prev_position, box[:2]]).intersects(self.r_s):
             self.counted_ids.append(track_id)
             # For linear region
             if dx > 0 and dy > 0:
@@ -64,20 +95,34 @@ class ObjectCounter(BaseSolution):

     def store_classwise_counts(self, cls):
         """
-        Initialize class-wise counts if not already present.
+        Initialize class-wise counts for a specific object class if not already present.

         Args:
-            cls (int): Class index for classwise count updates
+            cls (int): Class index for classwise count updates.
+
+        This method ensures that the 'classwise_counts' dictionary contains an entry for the specified class,
+        initializing 'IN' and 'OUT' counts to zero if the class is not already present.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> counter.store_classwise_counts(0)  # Initialize counts for class index 0
+            >>> print(counter.classwise_counts)
+            {'person': {'IN': 0, 'OUT': 0}}
         """
         if self.names[cls] not in self.classwise_counts:
             self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0}

     def display_counts(self, im0):
         """
-        Helper function to display object counts on the frame.
+        Displays object counts on the input image or frame.

         Args:
-            im0 (ndarray): The input image or frame
+            im0 (numpy.ndarray): The input image or frame to display counts on.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> frame = cv2.imread("image.jpg")
+            >>> counter.display_counts(frame)
         """
         labels_dict = {
             str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
@@ -91,12 +136,21 @@ class ObjectCounter(BaseSolution):

     def count(self, im0):
         """
-        Processes input data (frames or object tracks) and updates counts.
+        Processes input data (frames or object tracks) and updates object counts.
+
+        This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates
+        object counts, and displays the results on the input image.

         Args:
-            im0 (ndarray): The input image that will be used for processing
-        Returns
-            im0 (ndarray): The processed image for more usage
+            im0 (numpy.ndarray): The input image or frame to be processed.
+
+        Returns:
+            (numpy.ndarray): The processed image with annotations and count information.
+
+        Examples:
+            >>> counter = ObjectCounter()
+            >>> frame = cv2.imread("path/to/image.jpg")
+            >>> processed_frame = counter.count(frame)
         """
         if not self.region_initialized:
             self.initialize_region()
@@ -122,7 +176,9 @@ class ObjectCounter(BaseSolution):
             )

             # store previous position of track for object counting
-            prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
+            prev_position = None
+            if len(self.track_history[track_id]) > 1:
+                prev_position = self.track_history[track_id][-2]
             self.count_objects(self.track_line, box, track_id, prev_position, cls)  # Perform object counting

         self.display_counts(im0)  # Display the counts on the frame
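Note: a minimal sketch of the linear-region test shown in the diff above; a track is counted when the segment from its previous to current position intersects the counting line. The points below are synthetic:

from shapely.geometry import LineString

counting_line = LineString([(20, 400), (1080, 400)])  # assumed two-point region
prev_position, current_position = (500, 420), (510, 380)  # object moved upward across the line
crossed = LineString([prev_position, current_position]).intersects(counting_line)
print(crossed)  # True; the dx/dy signs relative to prev_position then decide IN vs OUT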
ultralytics/solutions/parking_management.py
@@ -10,10 +10,44 @@ from ultralytics.utils.plotting import Annotator


 class ParkingPtsSelection:
-    """Class for selecting and managing parking zone points on images using a Tkinter-based UI."""
+    """
+    A class for selecting and managing parking zone points on images using a Tkinter-based UI.
+
+    This class provides functionality to upload an image, select points to define parking zones, and save the
+    selected points to a JSON file. It uses Tkinter for the graphical user interface.
+
+    Attributes:
+        tk (module): The Tkinter module for GUI operations.
+        filedialog (module): Tkinter's filedialog module for file selection operations.
+        messagebox (module): Tkinter's messagebox module for displaying message boxes.
+        master (tk.Tk): The main Tkinter window.
+        canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.
+        image (PIL.Image.Image): The uploaded image.
+        canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.
+        rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points.
+        current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box.
+        imgw (int): Original width of the uploaded image.
+        imgh (int): Original height of the uploaded image.
+        canvas_max_width (int): Maximum width of the canvas.
+        canvas_max_height (int): Maximum height of the canvas.
+
+    Methods:
+        setup_ui: Sets up the Tkinter UI components.
+        initialize_properties: Initializes the necessary properties.
+        upload_image: Uploads an image, resizes it to fit the canvas, and displays it.
+        on_canvas_click: Handles mouse clicks to add points for bounding boxes.
+        draw_box: Draws a bounding box on the canvas.
+        remove_last_bounding_box: Removes the last bounding box and redraws the canvas.
+        redraw_canvas: Redraws the canvas with the image and all bounding boxes.
+        save_to_json: Saves the bounding boxes to a JSON file.
+
+    Examples:
+        >>> parking_selector = ParkingPtsSelection()
+        >>> # Use the GUI to upload an image, select parking zones, and save the data
+    """

     def __init__(self):
-        """Class initialization method."""
+        """Initializes the ParkingPtsSelection class, setting up UI and properties for parking zone point selection."""
         check_requirements("tkinter")
         import tkinter as tk
         from tkinter import filedialog, messagebox
@@ -24,7 +58,7 @@ class ParkingPtsSelection:
         self.master.mainloop()

     def setup_ui(self):
-        """Sets up the Tkinter UI components."""
+        """Sets up the Tkinter UI components for the parking zone points selection interface."""
         self.master = self.tk.Tk()
         self.master.title("Ultralytics Parking Zones Points Selector")
         self.master.resizable(False, False)
@@ -45,14 +79,14 @@ class ParkingPtsSelection:
             self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)

     def initialize_properties(self):
-        """Initialize the necessary properties."""
+        """Initialize properties for image, canvas, bounding boxes, and dimensions."""
         self.image = self.canvas_image = None
         self.rg_data, self.current_box = [], []
         self.imgw = self.imgh = 0
         self.canvas_max_width, self.canvas_max_height = 1280, 720

     def upload_image(self):
-        """Uploads an image, resizes it to fit the canvas, and displays it."""
+        """Uploads and displays an image on the canvas, resizing it to fit within specified dimensions."""
         from PIL import Image, ImageTk  # scope because ImageTk requires tkinter package

         self.image = Image.open(self.filedialog.askopenfilename(filetypes=[("Image Files", "*.png;*.jpg;*.jpeg")]))
@@ -76,7 +110,7 @@ class ParkingPtsSelection:
         self.rg_data.clear(), self.current_box.clear()

     def on_canvas_click(self, event):
-        """Handles mouse clicks to add points for bounding boxes."""
+        """Handles mouse clicks to add points for bounding boxes on the canvas."""
         self.current_box.append((event.x, event.y))
         self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red")
         if len(self.current_box) == 4:
@@ -85,12 +119,12 @@ class ParkingPtsSelection:
             self.current_box.clear()

     def draw_box(self, box):
-        """Draws a bounding box on the canvas."""
+        """Draws a bounding box on the canvas using the provided coordinates."""
         for i in range(4):
             self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2)

     def remove_last_bounding_box(self):
-        """Removes the last bounding box and redraws the canvas."""
+        """Removes the last bounding box from the list and redraws the canvas."""
         if not self.rg_data:
             self.messagebox.showwarning("Warning", "No bounding boxes to remove.")
             return
@@ -105,7 +139,7 @@ class ParkingPtsSelection:
             self.draw_box(box)

     def save_to_json(self):
-        """Saves the bounding boxes to a JSON file."""
+        """Saves the selected parking zone points to a JSON file with scaled coordinates."""
         scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
         data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]
         with open("bounding_boxes.json", "w") as f:
@@ -114,7 +148,30 @@ class ParkingPtsSelection:


 class ParkingManagement(BaseSolution):
-    """Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization."""
+    """
+    Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.
+
+    This class extends BaseSolution to provide functionality for parking lot management, including detection of
+    occupied spaces, visualization of parking regions, and display of occupancy statistics.
+
+    Attributes:
+        json_file (str): Path to the JSON file containing parking region details.
+        json (List[Dict]): Loaded JSON data containing parking region information.
+        pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).
+        arc (Tuple[int, int, int]): RGB color tuple for available region visualization.
+        occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization.
+        dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.
+
+    Methods:
+        process_data: Processes model data for parking lot management and visualization.
+
+    Examples:
+        >>> from ultralytics.solutions import ParkingManagement
+        >>> parking_manager = ParkingManagement(model="yolov8n.pt", json_file="parking_regions.json")
+        >>> results = parking_manager(source="parking_lot_video.mp4")
+        >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}")
+        >>> print(f"Available spaces: {parking_manager.pr_info['Available']}")
+    """

     def __init__(self, **kwargs):
         """Initializes the parking management system with a YOLO model and visualization settings."""
@@ -136,10 +193,19 @@ class ParkingManagement(BaseSolution):

     def process_data(self, im0):
         """
-        Process the model data for parking lot management.
+        Processes the model data for parking lot management.
+
+        This function analyzes the input image, extracts tracks, and determines the occupancy status of parking
+        regions defined in the JSON file. It annotates the image with occupied and available parking spots,
+        and updates the parking information.

         Args:
-            im0 (ndarray): inference image.
+            im0 (np.ndarray): The input inference image.
+
+        Examples:
+            >>> parking_manager = ParkingManagement(json_file="parking_regions.json")
+            >>> image = cv2.imread("parking_lot.jpg")
+            >>> parking_manager.process_data(image)
         """
         self.extract_tracks(im0)  # extract tracks from im0
         es, fs = len(self.json), 0  # empty slots, filled slots
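Note: a minimal sketch of the occupancy test implied above; each JSON region is a polygon, and a slot counts as filled when a detection centroid falls inside it. The data is synthetic, and shapely stands in here for brevity; the solution's own point-in-polygon implementation may differ:

import json

from shapely.geometry import Point, Polygon

regions = json.loads('[{"points": [[0, 0], [100, 0], [100, 60], [0, 60]]}]')  # stand-in for bounding_boxes.json
detections = [(40, 30), (300, 300)]  # detected object centroids (x, y)

occupied = 0
for region in regions:
    polygon = Polygon(region["points"])
    if any(polygon.contains(Point(x, y)) for x, y in detections):
        occupied += 1
print({"Occupancy": occupied, "Available": len(regions) - occupied})  # {'Occupancy': 1, 'Available': 0}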
ultralytics/solutions/queue_management.py
@@ -1,16 +1,40 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from shapely.geometry import Point
-
-from ultralytics.solutions.solutions import BaseSolution  # Import a parent class
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator, colors


 class QueueManager(BaseSolution):
-    """A class to manage the queue in a real-time video stream based on object tracks."""
+    """
+    Manages queue counting in real-time video streams based on object tracks.
+
+    This class extends BaseSolution to provide functionality for tracking and counting objects within a specified
+    region in video frames.
+
+    Attributes:
+        counts (int): The current count of objects in the queue.
+        rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.
+        region_length (int): The number of points defining the queue region.
+        annotator (Annotator): An instance of the Annotator class for drawing on frames.
+        track_line (List[Tuple[int, int]]): List of track line coordinates.
+        track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object.
+
+    Methods:
+        initialize_region: Initializes the queue region.
+        process_queue: Processes a single frame for queue management.
+        extract_tracks: Extracts object tracks from the current frame.
+        store_tracking_history: Stores the tracking history for an object.
+        display_output: Displays the processed output.
+
+    Examples:
+        >>> queue_manager = QueueManager(source="video.mp4", region=[100, 100, 200, 200, 300, 300])
+        >>> for frame in video_stream:
+        ...     processed_frame = queue_manager.process_queue(frame)
+        ...     cv2.imshow("Queue Management", processed_frame)
+    """

     def __init__(self, **kwargs):
-        """Initializes the QueueManager with specified parameters for tracking and counting objects."""
+        """Initializes the QueueManager with parameters for tracking and counting objects in a video stream."""
         super().__init__(**kwargs)
         self.initialize_region()
         self.counts = 0  # Queue counts Information
@@ -19,12 +43,31 @@ class QueueManager(BaseSolution):

     def process_queue(self, im0):
         """
-        Main function to start the queue management process.
+        Processes the queue management for a single frame of video.

         Args:
-            im0 (ndarray): The input image that will be used for processing
-        Returns
-            im0 (ndarray): The processed image for more usage
+            im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream.
+
+        Returns:
+            (numpy.ndarray): Processed image with annotations, bounding boxes, and queue counts.
+
+        This method performs the following steps:
+        1. Resets the queue count for the current frame.
+        2. Initializes an Annotator object for drawing on the image.
+        3. Extracts tracks from the image.
+        4. Draws the counting region on the image.
+        5. For each detected object:
+           - Draws bounding boxes and labels.
+           - Stores tracking history.
+           - Draws centroids and tracks.
+           - Checks if the object is inside the counting region and updates the count.
+        6. Displays the queue count on the image.
+        7. Displays the processed output.
+
+        Examples:
+            >>> queue_manager = QueueManager()
+            >>> frame = cv2.imread("frame.jpg")
+            >>> processed_frame = queue_manager.process_queue(frame)
         """
         self.counts = 0  # Reset counts every frame
         self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
@@ -48,8 +91,10 @@ class QueueManager(BaseSolution):
             track_history = self.track_history.get(track_id, [])

             # store previous position of track and check if the object is inside the counting region
-            prev_position = track_history[-2] if len(track_history) > 1 else None
-            if self.region_length >= 3 and prev_position and self.r_s.contains(Point(self.track_line[-1])):
+            prev_position = None
+            if len(track_history) > 1:
+                prev_position = track_history[-2]
+            if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):
                 self.counts += 1

             # Display queue counts
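Note: a minimal sketch of the per-frame queue count; objects whose latest track point lies inside the queue polygon are counted, and the count resets every frame. The region matches the test above, the track points are synthetic:

from shapely.geometry import Point, Polygon

queue_region = Polygon([(20, 400), (1080, 404), (1080, 360), (20, 360)])
latest_points = [(500, 380), (50, 390), (600, 100)]  # last track point per object
counts = sum(queue_region.contains(Point(p)) for p in latest_points)
print(counts)  # 2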
ultralytics/solutions/solutions.py
@@ -9,21 +9,51 @@ from ultralytics import YOLO
 from ultralytics.utils import LOGGER, yaml_load
 from ultralytics.utils.checks import check_imshow, check_requirements

-check_requirements("shapely>=2.0.0")
-from shapely.geometry import LineString, Polygon
-
 DEFAULT_SOL_CFG_PATH = Path(__file__).resolve().parents[1] / "cfg/solutions/default.yaml"


 class BaseSolution:
-    """A class to manage all the Ultralytics Solutions: https://docs.ultralytics.com/solutions/."""
+    """
+    A base class for managing Ultralytics Solutions.
+
+    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
+    and region initialization.
+
+    Attributes:
+        LineString (shapely.geometry.LineString): Class for creating line string geometries.
+        Polygon (shapely.geometry.Polygon): Class for creating polygon geometries.
+        Point (shapely.geometry.Point): Class for creating point geometries.
+        CFG (Dict): Configuration dictionary loaded from a YAML file and updated with kwargs.
+        region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest.
+        line_width (int): Width of lines used in visualizations.
+        model (ultralytics.YOLO): Loaded YOLO model instance.
+        names (Dict[int, str]): Dictionary mapping class indices to class names.
+        env_check (bool): Flag indicating whether the environment supports image display.
+        track_history (collections.defaultdict): Dictionary to store tracking history for each object.
+
+    Methods:
+        extract_tracks: Apply object tracking and extract tracks from an input image.
+        store_tracking_history: Store object tracking history for a given track ID and bounding box.
+        initialize_region: Initialize the counting region and line segment based on configuration.
+        display_output: Display the results of processing, including showing frames or saving results.
+
+    Examples:
+        >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
+        >>> solution.initialize_region()
+        >>> image = cv2.imread("image.jpg")
+        >>> solution.extract_tracks(image)
+        >>> solution.display_output(image)
+    """

     def __init__(self, **kwargs):
-        """
-        Base initializer for all solutions.
-
-        Child classes should call this with necessary parameters.
-        """
+        """Initializes the BaseSolution class with configuration settings and YOLO model for Ultralytics solutions."""
+        check_requirements("shapely>=2.0.0")
+        from shapely.geometry import LineString, Point, Polygon
+
+        self.LineString = LineString
+        self.Polygon = Polygon
+        self.Point = Point
+
         # Load config and update with args
         self.CFG = yaml_load(DEFAULT_SOL_CFG_PATH)
         self.CFG.update(kwargs)
@@ -42,10 +72,15 @@ class BaseSolution:

     def extract_tracks(self, im0):
         """
-        Apply object tracking and extract tracks.
+        Applies object tracking and extracts tracks from an input image or frame.

         Args:
-            im0 (ndarray): The input image or frame
+            im0 (ndarray): The input image or frame.
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> frame = cv2.imread("path/to/image.jpg")
+            >>> solution.extract_tracks(frame)
         """
         self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"])

@@ -62,11 +97,18 @@ class BaseSolution:

     def store_tracking_history(self, track_id, box):
         """
-        Store object tracking history.
+        Stores the tracking history of an object.
+
+        This method updates the tracking history for a given object by appending the center point of its
+        bounding box to the track line. It maintains a maximum of 30 points in the tracking history.

         Args:
-            track_id (int): The track ID of the object
-            box (list): Bounding box coordinates of the object
+            track_id (int): The unique identifier for the tracked object.
+            box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
         """
         # Store tracking history
         self.track_line = self.track_history[track_id]
@@ -75,19 +117,32 @@ class BaseSolution:
             self.track_line.pop(0)

     def initialize_region(self):
-        """Initialize the counting region and line segment based on config."""
-        self.region = [(20, 400), (1080, 404), (1080, 360), (20, 360)] if self.region is None else self.region
-        self.r_s = Polygon(self.region) if len(self.region) >= 3 else LineString(self.region)  # region segment
-        self.l_s = LineString(
-            [(self.region[0][0], self.region[0][1]), (self.region[1][0], self.region[1][1])]
-        )  # line segment
+        """Initialize the counting region and line segment based on configuration settings."""
+        if self.region is None:
+            self.region = [(20, 400), (1080, 404), (1080, 360), (20, 360)]
+        self.r_s = (
+            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
+        )  # region or line

     def display_output(self, im0):
         """
         Display the results of the processing, which could involve showing frames, printing counts, or saving results.
+
+        This method is responsible for visualizing the output of the object detection and tracking process. It displays
+        the processed frame with annotations, and allows for user interaction to close the display.

         Args:
-            im0 (ndarray): The input image or frame
+            im0 (numpy.ndarray): The input image or frame that has been processed and annotated.
+
+        Examples:
+            >>> solution = BaseSolution()
+            >>> frame = cv2.imread("path/to/image.jpg")
+            >>> solution.display_output(frame)
+
+        Notes:
+            - This method will only display output if the 'show' configuration is set to True and the environment
+              supports image display.
+            - The display can be closed by pressing the 'q' key.
         """
         if self.CFG.get("show") and self.env_check:
             cv2.imshow("Ultralytics Solutions", im0)
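Note: a minimal sketch of the 30-point track history that BaseSolution keeps, per the new docstring: a defaultdict of lists where each new box center is appended and the oldest point dropped past the cap. The simulated boxes below are synthetic:

from collections import defaultdict

track_history = defaultdict(list)


def store_tracking_history(track_id, box, max_len=30):
    """Append the box center to the track and trim it to the last `max_len` points."""
    track_line = track_history[track_id]
    track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))  # bbox center
    if len(track_line) > max_len:
        track_line.pop(0)


for i in range(40):  # simulate 40 frames of one object drifting right
    store_tracking_history(track_id=1, box=[i, 50, i + 10, 60])
print(len(track_history[1]))  # 30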
ultralytics/solutions/speed_estimation.py
@@ -4,15 +4,43 @@ from time import time

 import numpy as np

-from ultralytics.solutions.solutions import BaseSolution, LineString
+from ultralytics.solutions.solutions import BaseSolution
 from ultralytics.utils.plotting import Annotator, colors


 class SpeedEstimator(BaseSolution):
-    """A class to estimate the speed of objects in a real-time video stream based on their tracks."""
+    """
+    A class to estimate the speed of objects in a real-time video stream based on their tracks.
+
+    This class extends the BaseSolution class and provides functionality for estimating object speeds using
+    tracking data in video streams.
+
+    Attributes:
+        spd (Dict[int, float]): Dictionary storing speed data for tracked objects.
+        trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated.
+        trk_pt (Dict[int, float]): Dictionary storing previous timestamps for tracked objects.
+        trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects.
+        annotator (Annotator): Annotator object for drawing on images.
+        region (List[Tuple[int, int]]): List of points defining the speed estimation region.
+        track_line (List[Tuple[float, float]]): List of points representing the object's track.
+        r_s (LineString): LineString object representing the speed estimation region.
+
+    Methods:
+        initialize_region: Initializes the speed estimation region.
+        estimate_speed: Estimates the speed of objects based on tracking data.
+        store_tracking_history: Stores the tracking history for an object.
+        extract_tracks: Extracts tracks from the current frame.
+        display_output: Displays the output with annotations.
+
+    Examples:
+        >>> estimator = SpeedEstimator()
+        >>> frame = cv2.imread("frame.jpg")
+        >>> processed_frame = estimator.estimate_speed(frame)
+        >>> cv2.imshow("Speed Estimation", processed_frame)
+    """

     def __init__(self, **kwargs):
-        """Initializes the SpeedEstimator with the given parameters."""
+        """Initializes the SpeedEstimator object with speed estimation parameters and data structures."""
         super().__init__(**kwargs)

         self.initialize_region()  # Initialize speed region
@@ -27,9 +55,15 @@ class SpeedEstimator(BaseSolution):
         Estimates the speed of objects based on tracking data.

         Args:
-            im0 (ndarray): The input image that will be used for processing
-        Returns
-            im0 (ndarray): The processed image for more usage
+            im0 (np.ndarray): Input image for processing. Shape is typically (H, W, C) for RGB images.
+
+        Returns:
+            (np.ndarray): Processed image with speed estimations and annotations.
+
+        Examples:
+            >>> estimator = SpeedEstimator()
+            >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+            >>> processed_image = estimator.estimate_speed(image)
         """
         self.annotator = Annotator(im0, line_width=self.line_width)  # Initialize annotator
         self.extract_tracks(im0)  # Extract tracks
@@ -56,7 +90,7 @@ class SpeedEstimator(BaseSolution):
             )

             # Calculate object speed and direction based on region intersection
-            if LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.l_s):
+            if self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s):
                 direction = "known"
             else:
                 direction = "unknown"
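Note: a minimal sketch of speed from tracking data, i.e. pixel displacement between the previous and current positions divided by elapsed time; a real deployment also needs a pixel-to-meter scale, and the positions below are synthetic:

import math
import time

prev_position, prev_time = (100.0, 400.0), time.time() - 0.5  # seen half a second ago
curr_position, curr_time = (160.0, 400.0), time.time()

elapsed = curr_time - prev_time
pixels = math.hypot(curr_position[0] - prev_position[0], curr_position[1] - prev_position[1])
speed_px_per_s = pixels / elapsed
print(f"{speed_px_per_s:.0f} px/s")  # ~120 px/s for this synthetic track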
ultralytics/solutions/streamlit_inference.py
@@ -11,7 +11,7 @@ from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS


 def inference(model=None):
-    """Runs real-time object detection on video input using Ultralytics YOLO11 in a Streamlit application."""
+    """Performs real-time object detection on video input using YOLO in a Streamlit web application."""
     check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
     import streamlit as st

@@ -108,7 +108,7 @@ def inference(model=None):
             st.warning("Failed to read frame from webcam. Please make sure the webcam is connected properly.")
             break

-        prev_time = time.time()
+        prev_time = time.time()  # Store initial time for FPS calculation

         # Store model predictions
         if enable_trk == "Yes":
@@ -120,7 +120,6 @@ def inference(model=None):
         # Calculate model FPS
         curr_time = time.time()
         fps = 1 / (curr_time - prev_time)
-        prev_time = curr_time

         # display frame
         org_frame.image(frame, channels="BGR")
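Note: with prev_time now reset at the top of every loop iteration, the trailing `prev_time = curr_time` was dead code, which is why the diff drops it. A minimal sketch of the per-frame FPS measurement, with a sleep standing in for model inference:

import time

prev_time = time.time()  # store initial time for FPS calculation
time.sleep(0.04)  # pretend inference takes ~40 ms
fps = 1 / (time.time() - prev_time)
print(f"{fps:.1f} FPS")  # ~25 FPS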