Update Tracker docstrings (#15469)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
da2797a182
commit
b7c5db94b4
10 changed files with 501 additions and 196 deletions
|
|
@ -25,7 +25,7 @@ class STrack(BaseTrack):
|
|||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||
score (float): Confidence score of the track.
|
||||
tracklet_len (int): Length of the tracklet.
|
||||
cls (any): Class label for the object.
|
||||
cls (Any): Class label for the object.
|
||||
idx (int): Index or identifier for the object.
|
||||
frame_id (int): Current frame ID.
|
||||
start_frame (int): Frame where the object was first detected.
|
||||
|
|
@ -39,12 +39,31 @@ class STrack(BaseTrack):
|
|||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
|
||||
Examples:
|
||||
Initialize and activate a new track
|
||||
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
|
||||
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, xywh, score, cls):
|
||||
"""Initialize new STrack instance."""
|
||||
"""
|
||||
Initialize a new STrack instance.
|
||||
|
||||
Args:
|
||||
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
|
||||
(x, y) is the center, (w, h) are width and height, [a] is the optional angle, and idx is the id.
|
||||
score (float): Confidence score of the detection.
|
||||
cls (Any): Class label for the detected object.
|
||||
|
||||
Examples:
|
||||
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
|
||||
>>> score = 0.9
|
||||
>>> cls = 'person'
|
||||
>>> track = STrack(xywh, score, cls)
|
||||
"""
|
||||
super().__init__()
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
||||
|
|
@ -60,7 +79,7 @@ class STrack(BaseTrack):
|
|||
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
"""Predicts the next state (mean and covariance) of the object using the Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
|
|
@ -68,7 +87,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
||||
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
|
|
@ -83,7 +102,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
"""Update state tracks positions and covariances using a homography matrix."""
|
||||
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
|
|
@ -101,7 +120,7 @@ class STrack(BaseTrack):
|
|||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet."""
|
||||
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||
|
|
@ -114,7 +133,7 @@ class STrack(BaseTrack):
|
|||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
"""Reactivates a previously lost track using new detection data and updates its state and attributes."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
||||
)
|
||||
|
|
@ -136,6 +155,12 @@ class STrack(BaseTrack):
|
|||
Args:
|
||||
new_track (STrack): The new track containing updated information.
|
||||
frame_id (int): The ID of the current frame.
|
||||
|
||||
Examples:
|
||||
Update the state of a track with new detection information
|
||||
>>> track = STrack([100, 200, 50, 80, 0], score=0.9, cls='person')
|
||||
>>> new_track = STrack([105, 205, 55, 85, 0], score=0.95, cls='person')
|
||||
>>> track.update(new_track, 2)
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
|
@ -158,7 +183,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
||||
"""Returns the bounding box in top-left-width-height format from the current state estimate."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -168,16 +193,14 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
"""Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
||||
height.
|
||||
"""
|
||||
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
|
|
@ -185,14 +208,14 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def xywh(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
||||
"""Returns the current position of the bounding box in (center x, center y, width, height) format."""
|
||||
ret = np.asarray(self.tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def xywha(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
||||
"""Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
|
||||
if self.angle is None:
|
||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||
return self.xywh
|
||||
|
|
@ -200,12 +223,12 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def result(self):
|
||||
"""Get current tracking results."""
|
||||
"""Returns the current tracking results in the appropriate bounding box format."""
|
||||
coords = self.xyxy if self.angle is None else self.xywha
|
||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
"""Returns a string representation of the STrack object including start frame, end frame, and track ID."""
|
||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||
|
||||
|
||||
|
|
@ -213,18 +236,18 @@ class BYTETracker:
|
|||
"""
|
||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||
|
||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
||||
predicting the new object locations, and performs data association.
|
||||
Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
|
||||
It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
|
||||
the new object locations, and performs data association.
|
||||
|
||||
Attributes:
|
||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (list[STrack]): List of lost tracks.
|
||||
removed_stracks (list[STrack]): List of removed tracks.
|
||||
tracked_stracks (List[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (List[STrack]): List of lost tracks.
|
||||
removed_stracks (List[STrack]): List of removed tracks.
|
||||
frame_id (int): The current frame ID.
|
||||
args (namespace): Command-line arguments.
|
||||
args (Namespace): Command-line arguments.
|
||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||
kalman_filter (object): Kalman Filter object.
|
||||
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
|
||||
|
||||
Methods:
|
||||
update(results, img=None): Updates object tracker with new detections.
|
||||
|
|
@ -236,10 +259,27 @@ class BYTETracker:
|
|||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
||||
|
||||
Examples:
|
||||
Initialize BYTETracker and update with detection results
|
||||
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||
>>> results = yolo_model.detect(image)
|
||||
>>> tracked_objects = tracker.update(results)
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
"""
|
||||
Initialize a BYTETracker instance for object tracking.
|
||||
|
||||
Args:
|
||||
args (Namespace): Command-line arguments containing tracking parameters.
|
||||
frame_rate (int): Frame rate of the video sequence.
|
||||
|
||||
Examples:
|
||||
Initialize BYTETracker with command-line arguments and a frame rate of 30
|
||||
>>> args = Namespace(track_buffer=30)
|
||||
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||
"""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
|
@ -251,7 +291,7 @@ class BYTETracker:
|
|||
self.reset_id()
|
||||
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
"""Updates the tracker with new detections and returns the current list of tracked objects."""
|
||||
self.frame_id += 1
|
||||
activated_stracks = []
|
||||
refind_stracks = []
|
||||
|
|
@ -365,31 +405,31 @@ class BYTETracker:
|
|||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
"""Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
"""Initializes object tracking by creating one STrack per detection from the given detections, scores, and class labels."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
||||
"""Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
if self.args.fuse_score:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
"""Predict the next states for multiple tracks using Kalman filter."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Resets the ID counter of STrack."""
|
||||
"""Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
|
||||
STrack.reset_id()
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
"""Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
|
@ -399,7 +439,7 @@ class BYTETracker:
|
|||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
"""Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
|
|
@ -414,20 +454,13 @@ class BYTETracker:
|
|||
|
||||
@staticmethod
|
||||
def sub_stracks(tlista, tlistb):
|
||||
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
|
||||
stracks = {t.track_id: t for t in tlista}
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
"""
|
||||
"""Filters out the stracks present in the second list from the first list."""
|
||||
track_ids_b = {t.track_id for t in tlistb}
|
||||
return [t for t in tlista if t.track_id not in track_ids_b]
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IoU distance."""
|
||||
"""Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue