Update Tracker docstrings (#15469)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-14 00:32:15 +08:00 committed by GitHub
parent da2797a182
commit b7c5db94b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 501 additions and 196 deletions

View file

@ -25,7 +25,7 @@ class STrack(BaseTrack):
is_activated (bool): Boolean flag indicating if the track has been activated.
score (float): Confidence score of the track.
tracklet_len (int): Length of the tracklet.
cls (any): Class label for the object.
cls (Any): Class label for the object.
idx (int): Index or identifier for the object.
frame_id (int): Current frame ID.
start_frame (int): Frame where the object was first detected.
@ -39,12 +39,31 @@ class STrack(BaseTrack):
update(new_track, frame_id): Update the state of a matched track.
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
Examples:
Initialize and activate a new track
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
"""
shared_kalman = KalmanFilterXYAH()
def __init__(self, xywh, score, cls):
"""Initialize new STrack instance."""
"""
Initialize a new STrack instance.
Args:
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
(x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
score (float): Confidence score of the detection.
cls (Any): Class label for the detected object.
Examples:
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
>>> score = 0.9
>>> cls = 'person'
>>> track = STrack(xywh, score, cls)
"""
super().__init__()
# xywh+idx or xywha+idx
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
@ -60,7 +79,7 @@ class STrack(BaseTrack):
self.angle = xywh[4] if len(xywh) == 6 else None
def predict(self):
"""Predicts mean and covariance using Kalman filter."""
"""Predicts the next state (mean and covariance) of the object using the Kalman filter."""
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
mean_state[7] = 0
@ -68,7 +87,7 @@ class STrack(BaseTrack):
@staticmethod
def multi_predict(stracks):
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
if len(stracks) <= 0:
return
multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -83,7 +102,7 @@ class STrack(BaseTrack):
@staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)):
"""Update state tracks positions and covariances using a homography matrix."""
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks])
@ -101,7 +120,7 @@ class STrack(BaseTrack):
stracks[i].covariance = cov
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet."""
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
@ -114,7 +133,7 @@ class STrack(BaseTrack):
self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a previously lost track with a new detection."""
"""Reactivates a previously lost track using new detection data and updates its state and attributes."""
self.mean, self.covariance = self.kalman_filter.update(
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
)
@ -136,6 +155,12 @@ class STrack(BaseTrack):
Args:
new_track (STrack): The new track containing updated information.
frame_id (int): The ID of the current frame.
Examples:
Update the state of a track with new detection information
>>> track = STrack([100, 200, 50, 80, 1], score=0.9, cls='person')
>>> new_track = STrack([105, 205, 55, 85, 1], score=0.95, cls='person')
>>> track.update(new_track, 2)
"""
self.frame_id = frame_id
self.tracklet_len += 1
@ -158,7 +183,7 @@ class STrack(BaseTrack):
@property
def tlwh(self):
"""Get current position in bounding box format (top left x, top left y, width, height)."""
"""Returns the bounding box in top-left-width-height format from the current state estimate."""
if self.mean is None:
return self._tlwh.copy()
ret = self.mean[:4].copy()
@ -168,16 +193,14 @@ class STrack(BaseTrack):
@property
def xyxy(self):
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
"""Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
ret = self.tlwh.copy()
ret[2:] += ret[:2]
return ret
@staticmethod
def tlwh_to_xyah(tlwh):
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
height.
"""
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
ret = np.asarray(tlwh).copy()
ret[:2] += ret[2:] / 2
ret[2] /= ret[3]
@ -185,14 +208,14 @@ class STrack(BaseTrack):
@property
def xywh(self):
"""Get current position in bounding box format (center x, center y, width, height)."""
"""Returns the current position of the bounding box in (center x, center y, width, height) format."""
ret = np.asarray(self.tlwh).copy()
ret[:2] += ret[2:] / 2
return ret
@property
def xywha(self):
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
"""Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
if self.angle is None:
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
return self.xywh
@ -200,12 +223,12 @@ class STrack(BaseTrack):
@property
def result(self):
"""Get current tracking results."""
"""Returns the current tracking results in the appropriate bounding box format."""
coords = self.xyxy if self.angle is None else self.xywha
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
"""Returns a string representation of the STrack object including start frame, end frame, and track ID."""
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
@ -213,18 +236,18 @@ class BYTETracker:
"""
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
predicting the new object locations, and performs data association.
Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
the new object locations, and performs data association.
Attributes:
tracked_stracks (list[STrack]): List of successfully activated tracks.
lost_stracks (list[STrack]): List of lost tracks.
removed_stracks (list[STrack]): List of removed tracks.
tracked_stracks (List[STrack]): List of successfully activated tracks.
lost_stracks (List[STrack]): List of lost tracks.
removed_stracks (List[STrack]): List of removed tracks.
frame_id (int): The current frame ID.
args (namespace): Command-line arguments.
args (Namespace): Command-line arguments.
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
kalman_filter (object): Kalman Filter object.
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
Methods:
update(results, img=None): Updates object tracker with new detections.
@ -236,10 +259,27 @@ class BYTETracker:
joint_stracks(tlista, tlistb): Combines two lists of stracks.
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
Examples:
Initialize BYTETracker and update with detection results
>>> tracker = BYTETracker(args, frame_rate=30)
>>> results = yolo_model.detect(image)
>>> tracked_objects = tracker.update(results)
"""
def __init__(self, args, frame_rate=30):
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
"""
Initialize a BYTETracker instance for object tracking.
Args:
args (Namespace): Command-line arguments containing tracking parameters.
frame_rate (int): Frame rate of the video sequence.
Examples:
Initialize BYTETracker with command-line arguments and a frame rate of 30
>>> args = Namespace(track_buffer=30)
>>> tracker = BYTETracker(args, frame_rate=30)
"""
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
@ -251,7 +291,7 @@ class BYTETracker:
self.reset_id()
def update(self, results, img=None):
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
"""Updates the tracker with new detections and returns the current list of tracked objects."""
self.frame_id += 1
activated_stracks = []
refind_stracks = []
@ -365,31 +405,31 @@ class BYTETracker:
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
"""Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
return KalmanFilterXYAH()
def init_track(self, dets, scores, cls, img=None):
"""Initialize object tracking with detections and scores using STrack algorithm."""
"""Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
def get_dists(self, tracks, detections):
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
"""Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
dists = matching.iou_distance(tracks, detections)
if self.args.fuse_score:
dists = matching.fuse_score(dists, detections)
return dists
def multi_predict(self, tracks):
"""Returns the predicted tracks using the YOLOv8 network."""
"""Predict the next states for multiple tracks using Kalman filter."""
STrack.multi_predict(tracks)
@staticmethod
def reset_id():
"""Resets the ID counter of STrack."""
"""Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
STrack.reset_id()
def reset(self):
"""Reset tracker."""
"""Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack]
@ -399,7 +439,7 @@ class BYTETracker:
@staticmethod
def joint_stracks(tlista, tlistb):
"""Combine two lists of stracks into a single one."""
"""Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
exists = {}
res = []
for t in tlista:
@ -414,20 +454,13 @@ class BYTETracker:
@staticmethod
def sub_stracks(tlista, tlistb):
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
stracks = {t.track_id: t for t in tlista}
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
"""
"""Filters out the stracks present in the second list from the first list."""
track_ids_b = {t.track_id for t in tlistb}
return [t for t in tlista if t.track_id not in track_ids_b]
@staticmethod
def remove_duplicate_stracks(stracksa, stracksb):
"""Remove duplicate stracks with non-maximum IoU distance."""
"""Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15)
dupa, dupb = [], []