Update Tracker docstrings (#15469)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
da2797a182
commit
b7c5db94b4
10 changed files with 501 additions and 196 deletions
|
|
@ -15,6 +15,9 @@ class BOTrack(STrack):
|
|||
"""
|
||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||
|
||||
This class extends the STrack class to include additional functionalities for object tracking, such as feature
|
||||
smoothing, Kalman filter prediction, and reactivation of tracks.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||
|
|
@ -34,16 +37,35 @@ class BOTrack(STrack):
|
|||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||
|
||||
Usage:
|
||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
bo_track.predict()
|
||||
bo_track.update(new_track, frame_id)
|
||||
Examples:
|
||||
Create a BOTrack instance and update its features
|
||||
>>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
|
||||
>>> bo_track.predict()
|
||||
>>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
|
||||
>>> bo_track.update(new_track, frame_id=2)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
||||
"""
|
||||
Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
|
||||
|
||||
Args:
|
||||
tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
|
||||
score (float): Confidence score of the detection.
|
||||
cls (int): Class ID of the detected object.
|
||||
feat (np.ndarray | None): Feature vector associated with the detection.
|
||||
feat_history (int): Maximum length of the feature history deque.
|
||||
|
||||
Examples:
|
||||
Initialize a BOTrack object with bounding box, score, class ID, and feature vector
|
||||
>>> tlwh = np.array([100, 50, 80, 120])
|
||||
>>> score = 0.9
|
||||
>>> cls = 1
|
||||
>>> feat = np.random.rand(128)
|
||||
>>> bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
"""
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
|
|
@ -54,7 +76,7 @@ class BOTrack(STrack):
|
|||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
"""Update features vector and smooth it using exponential moving average."""
|
||||
"""Update the feature vector and apply exponential moving average smoothing."""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
|
|
@ -65,7 +87,7 @@ class BOTrack(STrack):
|
|||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""Predicts the mean and covariance using Kalman filter."""
|
||||
"""Predicts the object's future state using the Kalman filter to update its mean and covariance."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
|
|
@ -80,14 +102,14 @@ class BOTrack(STrack):
|
|||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
||||
"""Updates the YOLOv8 instance with new track information and the current frame ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
||||
"""Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -96,7 +118,7 @@ class BOTrack(STrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
||||
"""Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
|
|
@ -111,12 +133,12 @@ class BOTrack(STrack):
|
|||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
||||
"""Converts tlwh bounding box coordinates to xywh format."""
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
||||
"""Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
|
@ -129,9 +151,9 @@ class BOTSORT(BYTETracker):
|
|||
Attributes:
|
||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
args (Any): Parsed command-line arguments containing tracking parameters.
|
||||
|
||||
Methods:
|
||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||
|
|
@ -139,17 +161,29 @@ class BOTSORT(BYTETracker):
|
|||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||
|
||||
Usage:
|
||||
bot_sort = BOTSORT(args, frame_rate)
|
||||
bot_sort.init_track(dets, scores, cls, img)
|
||||
bot_sort.multi_predict(tracks)
|
||||
Examples:
|
||||
Initialize BOTSORT and process detections
|
||||
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||
>>> bot_sort.init_track(dets, scores, cls, img)
|
||||
>>> bot_sort.multi_predict(tracks)
|
||||
|
||||
Note:
|
||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
"""
|
||||
Initialize YOLOv8 object with ReID module and GMC algorithm.
|
||||
|
||||
Args:
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
frame_rate (int): Frame rate of the video being processed.
|
||||
|
||||
Examples:
|
||||
Initialize BOTSORT with command-line arguments and a specified frame rate:
|
||||
>>> args = parse_args()
|
||||
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||
"""
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
|
|
@ -161,11 +195,11 @@ class BOTSORT(BYTETracker):
|
|||
self.gmc = GMC(method=args.gmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
"""Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
"""Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
|
|
@ -175,7 +209,7 @@ class BOTSORT(BYTETracker):
|
|||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
"""Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = dists > self.proximity_thresh
|
||||
|
||||
|
|
@ -190,10 +224,10 @@ class BOTSORT(BYTETracker):
|
|||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
"""Predicts the mean and covariance of multiple object tracks using a shared Kalman filter."""
|
||||
BOTrack.multi_predict(tracks)
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
"""Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
|
||||
super().reset()
|
||||
self.gmc.reset_params()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue