Update Tracker docstrings (#15469)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-14 00:32:15 +08:00 committed by GitHub
parent da2797a182
commit b7c5db94b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 501 additions and 196 deletions

View file

@ -15,6 +15,9 @@ class BOTrack(STrack):
"""
An extended version of the STrack class for YOLOv8, adding object tracking features.
This class extends the STrack class to include additional functionalities for object tracking, such as feature
smoothing, Kalman filter prediction, and reactivation of tracks.
Attributes:
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
smooth_feat (np.ndarray): Smoothed feature vector.
@ -34,16 +37,35 @@ class BOTrack(STrack):
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
Usage:
bo_track = BOTrack(tlwh, score, cls, feat)
bo_track.predict()
bo_track.update(new_track, frame_id)
Examples:
Create a BOTrack instance and update its features
>>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
>>> bo_track.predict()
>>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
>>> bo_track.update(new_track, frame_id=2)
"""
shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
"""
Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
Args:
tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
score (float): Confidence score of the detection.
cls (int): Class ID of the detected object.
feat (np.ndarray | None): Feature vector associated with the detection.
feat_history (int): Maximum length of the feature history deque.
Examples:
Initialize a BOTrack object with bounding box, score, class ID, and feature vector
>>> tlwh = np.array([100, 50, 80, 120])
>>> score = 0.9
>>> cls = 1
>>> feat = np.random.rand(128)
>>> bo_track = BOTrack(tlwh, score, cls, feat)
"""
super().__init__(tlwh, score, cls)
self.smooth_feat = None
@ -54,7 +76,7 @@ class BOTrack(STrack):
self.alpha = 0.9
def update_features(self, feat):
"""Update features vector and smooth it using exponential moving average."""
"""Update the feature vector and apply exponential moving average smoothing."""
feat /= np.linalg.norm(feat)
self.curr_feat = feat
if self.smooth_feat is None:
@ -65,7 +87,7 @@ class BOTrack(STrack):
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self):
"""Predicts the mean and covariance using Kalman filter."""
"""Predicts the object's future state using the Kalman filter to update its mean and covariance."""
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[6] = 0
@ -80,14 +102,14 @@ class BOTrack(STrack):
super().re_activate(new_track, frame_id, new_id)
def update(self, new_track, frame_id):
    """
    Refresh this track from a newly matched track.

    When the matched track carries a current feature vector, it is passed to `update_features` before the
    parent class `update` is invoked with the same arguments.

    Args:
        new_track: Matched track for the current frame; its `curr_feat` attribute is read here.
        frame_id: Identifier of the current frame, forwarded to the parent class unchanged.
    """
    latest_feat = new_track.curr_feat
    if latest_feat is not None:
        self.update_features(latest_feat)
    super().update(new_track, frame_id)
@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
"""Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
@ -96,7 +118,7 @@ class BOTrack(STrack):
@staticmethod
def multi_predict(stracks):
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
"""Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
if len(stracks) <= 0:
return
multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -111,12 +133,12 @@ class BOTrack(STrack):
stracks[i].covariance = cov
def convert_coords(self, tlwh):
    """
    Convert a tlwh bounding box to xywh format by delegating to `tlwh_to_xywh`.

    Args:
        tlwh: Bounding box given as (top left x, top left y, width, height).

    Returns:
        The value produced by `tlwh_to_xywh` for the same box.
    """
    xywh = self.tlwh_to_xywh(tlwh)
    return xywh
@staticmethod
def tlwh_to_xywh(tlwh):
    """
    Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.

    Args:
        tlwh (array-like): Box as (top left x, top left y, width, height). The input is never modified.

    Returns:
        (np.ndarray): New floating-point array holding (center x, center y, width, height). Floating inputs
            keep their dtype; any other dtype is promoted to float.

    Examples:
        >>> BOTrack.tlwh_to_xywh([100, 50, 80, 120])
        array([140., 110.,  80., 120.])
    """
    ret = np.asarray(tlwh)
    # Integer boxes must be promoted before the in-place add below: NumPy refuses to store the half-size
    # offsets (float) into an integer array and raises a casting error. `astype` already returns a copy.
    ret = ret.copy() if np.issubdtype(ret.dtype, np.floating) else ret.astype(float)
    ret[:2] += ret[2:] / 2
    return ret
@ -129,9 +151,9 @@ class BOTSORT(BYTETracker):
Attributes:
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
gmc (GMC): An instance of the GMC algorithm for data association.
args (object): Parsed command-line arguments containing tracking parameters.
args (Any): Parsed command-line arguments containing tracking parameters.
Methods:
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
@ -139,17 +161,29 @@ class BOTSORT(BYTETracker):
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
multi_predict(tracks): Predict the mean and covariance of multiple object tracks using a shared Kalman filter.
Usage:
bot_sort = BOTSORT(args, frame_rate)
bot_sort.init_track(dets, scores, cls, img)
bot_sort.multi_predict(tracks)
Examples:
Initialize BOTSORT and process detections
>>> bot_sort = BOTSORT(args, frame_rate=30)
>>> bot_sort.init_track(dets, scores, cls, img)
>>> bot_sort.multi_predict(tracks)
Note:
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
"""
def __init__(self, args, frame_rate=30):
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
"""
Initialize a BOTSORT tracker with ReID module and GMC algorithm.
Args:
args (Any): Parsed command-line arguments containing tracking parameters.
frame_rate (int): Frame rate of the video being processed.
Examples:
Initialize BOTSORT with command-line arguments and a specified frame rate:
>>> args = parse_args()
>>> bot_sort = BOTSORT(args, frame_rate=30)
"""
super().__init__(args, frame_rate)
# ReID module
self.proximity_thresh = args.proximity_thresh
@ -161,11 +195,11 @@ class BOTSORT(BYTETracker):
self.gmc = GMC(method=args.gmc_method)
def get_kalmanfilter(self):
    """
    Create the Kalman filter used by this tracker.

    Returns:
        (KalmanFilterXYWH): A newly constructed filter; a fresh object is created on every call.
    """
    kalman_filter = KalmanFilterXYWH()
    return kalman_filter
def init_track(self, dets, scores, cls, img=None):
"""Initialize track with detections, scores, and classes."""
"""Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
if len(dets) == 0:
return []
if self.args.with_reid and self.encoder is not None:
@ -175,7 +209,7 @@ class BOTSORT(BYTETracker):
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
"""Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
dists = matching.iou_distance(tracks, detections)
dists_mask = dists > self.proximity_thresh
@ -190,10 +224,10 @@ class BOTSORT(BYTETracker):
return dists
def multi_predict(self, tracks):
    """
    Run the prediction step for a batch of tracks.

    Delegates to the `BOTrack.multi_predict` static method and returns nothing.

    Args:
        tracks: Collection of tracks handed to `BOTrack.multi_predict` unchanged.
    """
    batch_predict = BOTrack.multi_predict
    batch_predict(tracks)
def reset(self):
    """
    Reset the tracker.

    Calls the parent class `reset` first, then `reset_params` on the GMC instance stored in `self.gmc`.
    """
    super().reset()
    gmc = self.gmc
    gmc.reset_params()