Update Tracker docstrings (#15469)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
da2797a182
commit
b7c5db94b4
10 changed files with 501 additions and 196 deletions
|
|
@ -81,7 +81,7 @@ CLI_HELP_MSG = f"""
|
||||||
5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
|
5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
|
||||||
yolo explorer data=data.yaml model=yolov8n.pt
|
yolo explorer data=data.yaml model=yolov8n.pt
|
||||||
|
|
||||||
6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8
|
6. Streamlit real-time webcam inference GUI
|
||||||
yolo streamlit-predict
|
yolo streamlit-predict
|
||||||
|
|
||||||
7. Run special commands:
|
7. Run special commands:
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,11 @@ class TrackState:
|
||||||
Tracked (int): State when the object is successfully tracked in subsequent frames.
|
Tracked (int): State when the object is successfully tracked in subsequent frames.
|
||||||
Lost (int): State when the object is no longer tracked.
|
Lost (int): State when the object is no longer tracked.
|
||||||
Removed (int): State when the object is removed from tracking.
|
Removed (int): State when the object is removed from tracking.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> state = TrackState.New
|
||||||
|
>>> if state == TrackState.New:
|
||||||
|
>>> print("Object is newly detected.")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
New = 0
|
New = 0
|
||||||
|
|
@ -33,13 +38,13 @@ class BaseTrack:
|
||||||
is_activated (bool): Flag indicating whether the track is currently active.
|
is_activated (bool): Flag indicating whether the track is currently active.
|
||||||
state (TrackState): Current state of the track.
|
state (TrackState): Current state of the track.
|
||||||
history (OrderedDict): Ordered history of the track's states.
|
history (OrderedDict): Ordered history of the track's states.
|
||||||
features (list): List of features extracted from the object for tracking.
|
features (List): List of features extracted from the object for tracking.
|
||||||
curr_feature (any): The current feature of the object being tracked.
|
curr_feature (Any): The current feature of the object being tracked.
|
||||||
score (float): The confidence score of the tracking.
|
score (float): The confidence score of the tracking.
|
||||||
start_frame (int): The frame number where tracking started.
|
start_frame (int): The frame number where tracking started.
|
||||||
frame_id (int): The most recent frame ID processed by the track.
|
frame_id (int): The most recent frame ID processed by the track.
|
||||||
time_since_update (int): Frames passed since the last update.
|
time_since_update (int): Frames passed since the last update.
|
||||||
location (tuple): The location of the object in the context of multi-camera tracking.
|
location (Tuple): The location of the object in the context of multi-camera tracking.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
end_frame: Returns the ID of the last frame where the object was tracked.
|
end_frame: Returns the ID of the last frame where the object was tracked.
|
||||||
|
|
@ -50,12 +55,26 @@ class BaseTrack:
|
||||||
mark_lost: Marks the track as lost.
|
mark_lost: Marks the track as lost.
|
||||||
mark_removed: Marks the track as removed.
|
mark_removed: Marks the track as removed.
|
||||||
reset_id: Resets the global track ID counter.
|
reset_id: Resets the global track ID counter.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize a new track and mark it as lost:
|
||||||
|
>>> track = BaseTrack()
|
||||||
|
>>> track.mark_lost()
|
||||||
|
>>> print(track.state) # Output: 2 (TrackState.Lost)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_count = 0
|
_count = 0
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initializes a new track with unique ID and foundational tracking attributes."""
|
"""
|
||||||
|
Initializes a new track with a unique ID and foundational tracking attributes.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize a new track
|
||||||
|
>>> track = BaseTrack()
|
||||||
|
>>> print(track.track_id)
|
||||||
|
0
|
||||||
|
"""
|
||||||
self.track_id = 0
|
self.track_id = 0
|
||||||
self.is_activated = False
|
self.is_activated = False
|
||||||
self.state = TrackState.New
|
self.state = TrackState.New
|
||||||
|
|
@ -70,36 +89,36 @@ class BaseTrack:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def end_frame(self):
|
def end_frame(self):
|
||||||
"""Return the last frame ID of the track."""
|
"""Returns the ID of the most recent frame where the object was tracked."""
|
||||||
return self.frame_id
|
return self.frame_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def next_id():
|
def next_id():
|
||||||
"""Increment and return the global track ID counter."""
|
"""Increment and return the next unique global track ID for object tracking."""
|
||||||
BaseTrack._count += 1
|
BaseTrack._count += 1
|
||||||
return BaseTrack._count
|
return BaseTrack._count
|
||||||
|
|
||||||
def activate(self, *args):
|
def activate(self, *args):
|
||||||
"""Abstract method to activate the track with provided arguments."""
|
"""Activates the track with provided arguments, initializing necessary attributes for tracking."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def predict(self):
|
def predict(self):
|
||||||
"""Abstract method to predict the next state of the track."""
|
"""Predicts the next state of the track based on the current state and tracking model."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def update(self, *args, **kwargs):
|
def update(self, *args, **kwargs):
|
||||||
"""Abstract method to update the track with new observations."""
|
"""Updates the track with new observations and data, modifying its state and attributes accordingly."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def mark_lost(self):
|
def mark_lost(self):
|
||||||
"""Mark the track as lost."""
|
"""Marks the track as lost by updating its state to TrackState.Lost."""
|
||||||
self.state = TrackState.Lost
|
self.state = TrackState.Lost
|
||||||
|
|
||||||
def mark_removed(self):
|
def mark_removed(self):
|
||||||
"""Mark the track as removed."""
|
"""Marks the track as removed by setting its state to TrackState.Removed."""
|
||||||
self.state = TrackState.Removed
|
self.state = TrackState.Removed
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reset_id():
|
def reset_id():
|
||||||
"""Reset the global track ID counter."""
|
"""Reset the global track ID counter to its initial value."""
|
||||||
BaseTrack._count = 0
|
BaseTrack._count = 0
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,9 @@ class BOTrack(STrack):
|
||||||
"""
|
"""
|
||||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||||
|
|
||||||
|
This class extends the STrack class to include additional functionalities for object tracking, such as feature
|
||||||
|
smoothing, Kalman filter prediction, and reactivation of tracks.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||||
|
|
@ -34,16 +37,35 @@ class BOTrack(STrack):
|
||||||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||||
|
|
||||||
Usage:
|
Examples:
|
||||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
Create a BOTrack instance and update its features
|
||||||
bo_track.predict()
|
>>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
|
||||||
bo_track.update(new_track, frame_id)
|
>>> bo_track.predict()
|
||||||
|
>>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
|
||||||
|
>>> bo_track.update(new_track, frame_id=2)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
shared_kalman = KalmanFilterXYWH()
|
shared_kalman = KalmanFilterXYWH()
|
||||||
|
|
||||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
"""
|
||||||
|
Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
|
||||||
|
score (float): Confidence score of the detection.
|
||||||
|
cls (int): Class ID of the detected object.
|
||||||
|
feat (np.ndarray | None): Feature vector associated with the detection.
|
||||||
|
feat_history (int): Maximum length of the feature history deque.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize a BOTrack object with bounding box, score, class ID, and feature vector
|
||||||
|
>>> tlwh = np.array([100, 50, 80, 120])
|
||||||
|
>>> score = 0.9
|
||||||
|
>>> cls = 1
|
||||||
|
>>> feat = np.random.rand(128)
|
||||||
|
>>> bo_track = BOTrack(tlwh, score, cls, feat)
|
||||||
|
"""
|
||||||
super().__init__(tlwh, score, cls)
|
super().__init__(tlwh, score, cls)
|
||||||
|
|
||||||
self.smooth_feat = None
|
self.smooth_feat = None
|
||||||
|
|
@ -54,7 +76,7 @@ class BOTrack(STrack):
|
||||||
self.alpha = 0.9
|
self.alpha = 0.9
|
||||||
|
|
||||||
def update_features(self, feat):
|
def update_features(self, feat):
|
||||||
"""Update features vector and smooth it using exponential moving average."""
|
"""Update the feature vector and apply exponential moving average smoothing."""
|
||||||
feat /= np.linalg.norm(feat)
|
feat /= np.linalg.norm(feat)
|
||||||
self.curr_feat = feat
|
self.curr_feat = feat
|
||||||
if self.smooth_feat is None:
|
if self.smooth_feat is None:
|
||||||
|
|
@ -65,7 +87,7 @@ class BOTrack(STrack):
|
||||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||||
|
|
||||||
def predict(self):
|
def predict(self):
|
||||||
"""Predicts the mean and covariance using Kalman filter."""
|
"""Predicts the object's future state using the Kalman filter to update its mean and covariance."""
|
||||||
mean_state = self.mean.copy()
|
mean_state = self.mean.copy()
|
||||||
if self.state != TrackState.Tracked:
|
if self.state != TrackState.Tracked:
|
||||||
mean_state[6] = 0
|
mean_state[6] = 0
|
||||||
|
|
@ -80,14 +102,14 @@ class BOTrack(STrack):
|
||||||
super().re_activate(new_track, frame_id, new_id)
|
super().re_activate(new_track, frame_id, new_id)
|
||||||
|
|
||||||
def update(self, new_track, frame_id):
|
def update(self, new_track, frame_id):
|
||||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
"""Updates the YOLOv8 instance with new track information and the current frame ID."""
|
||||||
if new_track.curr_feat is not None:
|
if new_track.curr_feat is not None:
|
||||||
self.update_features(new_track.curr_feat)
|
self.update_features(new_track.curr_feat)
|
||||||
super().update(new_track, frame_id)
|
super().update(new_track, frame_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tlwh(self):
|
def tlwh(self):
|
||||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
"""Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
|
||||||
if self.mean is None:
|
if self.mean is None:
|
||||||
return self._tlwh.copy()
|
return self._tlwh.copy()
|
||||||
ret = self.mean[:4].copy()
|
ret = self.mean[:4].copy()
|
||||||
|
|
@ -96,7 +118,7 @@ class BOTrack(STrack):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def multi_predict(stracks):
|
def multi_predict(stracks):
|
||||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
"""Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
|
||||||
if len(stracks) <= 0:
|
if len(stracks) <= 0:
|
||||||
return
|
return
|
||||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||||
|
|
@ -111,12 +133,12 @@ class BOTrack(STrack):
|
||||||
stracks[i].covariance = cov
|
stracks[i].covariance = cov
|
||||||
|
|
||||||
def convert_coords(self, tlwh):
|
def convert_coords(self, tlwh):
|
||||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
"""Converts tlwh bounding box coordinates to xywh format."""
|
||||||
return self.tlwh_to_xywh(tlwh)
|
return self.tlwh_to_xywh(tlwh)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tlwh_to_xywh(tlwh):
|
def tlwh_to_xywh(tlwh):
|
||||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
"""Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
|
||||||
ret = np.asarray(tlwh).copy()
|
ret = np.asarray(tlwh).copy()
|
||||||
ret[:2] += ret[2:] / 2
|
ret[:2] += ret[2:] / 2
|
||||||
return ret
|
return ret
|
||||||
|
|
@ -129,9 +151,9 @@ class BOTSORT(BYTETracker):
|
||||||
Attributes:
|
Attributes:
|
||||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||||
args (object): Parsed command-line arguments containing tracking parameters.
|
args (Any): Parsed command-line arguments containing tracking parameters.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||||
|
|
@ -139,17 +161,29 @@ class BOTSORT(BYTETracker):
|
||||||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||||
|
|
||||||
Usage:
|
Examples:
|
||||||
bot_sort = BOTSORT(args, frame_rate)
|
Initialize BOTSORT and process detections
|
||||||
bot_sort.init_track(dets, scores, cls, img)
|
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||||
bot_sort.multi_predict(tracks)
|
>>> bot_sort.init_track(dets, scores, cls, img)
|
||||||
|
>>> bot_sort.multi_predict(tracks)
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, frame_rate=30):
|
def __init__(self, args, frame_rate=30):
|
||||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
"""
|
||||||
|
Initialize YOLOv8 object with ReID module and GMC algorithm.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (object): Parsed command-line arguments containing tracking parameters.
|
||||||
|
frame_rate (int): Frame rate of the video being processed.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize BOTSORT with command-line arguments and a specified frame rate:
|
||||||
|
>>> args = parse_args()
|
||||||
|
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||||
|
"""
|
||||||
super().__init__(args, frame_rate)
|
super().__init__(args, frame_rate)
|
||||||
# ReID module
|
# ReID module
|
||||||
self.proximity_thresh = args.proximity_thresh
|
self.proximity_thresh = args.proximity_thresh
|
||||||
|
|
@ -161,11 +195,11 @@ class BOTSORT(BYTETracker):
|
||||||
self.gmc = GMC(method=args.gmc_method)
|
self.gmc = GMC(method=args.gmc_method)
|
||||||
|
|
||||||
def get_kalmanfilter(self):
|
def get_kalmanfilter(self):
|
||||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
"""Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
|
||||||
return KalmanFilterXYWH()
|
return KalmanFilterXYWH()
|
||||||
|
|
||||||
def init_track(self, dets, scores, cls, img=None):
|
def init_track(self, dets, scores, cls, img=None):
|
||||||
"""Initialize track with detections, scores, and classes."""
|
"""Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
|
||||||
if len(dets) == 0:
|
if len(dets) == 0:
|
||||||
return []
|
return []
|
||||||
if self.args.with_reid and self.encoder is not None:
|
if self.args.with_reid and self.encoder is not None:
|
||||||
|
|
@ -175,7 +209,7 @@ class BOTSORT(BYTETracker):
|
||||||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||||
|
|
||||||
def get_dists(self, tracks, detections):
|
def get_dists(self, tracks, detections):
|
||||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
"""Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
|
||||||
dists = matching.iou_distance(tracks, detections)
|
dists = matching.iou_distance(tracks, detections)
|
||||||
dists_mask = dists > self.proximity_thresh
|
dists_mask = dists > self.proximity_thresh
|
||||||
|
|
||||||
|
|
@ -190,10 +224,10 @@ class BOTSORT(BYTETracker):
|
||||||
return dists
|
return dists
|
||||||
|
|
||||||
def multi_predict(self, tracks):
|
def multi_predict(self, tracks):
|
||||||
"""Predict and track multiple objects with YOLOv8 model."""
|
"""Predicts the mean and covariance of multiple object tracks using a shared Kalman filter."""
|
||||||
BOTrack.multi_predict(tracks)
|
BOTrack.multi_predict(tracks)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset tracker."""
|
"""Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
|
||||||
super().reset()
|
super().reset()
|
||||||
self.gmc.reset_params()
|
self.gmc.reset_params()
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ class STrack(BaseTrack):
|
||||||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||||
score (float): Confidence score of the track.
|
score (float): Confidence score of the track.
|
||||||
tracklet_len (int): Length of the tracklet.
|
tracklet_len (int): Length of the tracklet.
|
||||||
cls (any): Class label for the object.
|
cls (Any): Class label for the object.
|
||||||
idx (int): Index or identifier for the object.
|
idx (int): Index or identifier for the object.
|
||||||
frame_id (int): Current frame ID.
|
frame_id (int): Current frame ID.
|
||||||
start_frame (int): Frame where the object was first detected.
|
start_frame (int): Frame where the object was first detected.
|
||||||
|
|
@ -39,12 +39,31 @@ class STrack(BaseTrack):
|
||||||
update(new_track, frame_id): Update the state of a matched track.
|
update(new_track, frame_id): Update the state of a matched track.
|
||||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize and activate a new track
|
||||||
|
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
|
||||||
|
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
shared_kalman = KalmanFilterXYAH()
|
shared_kalman = KalmanFilterXYAH()
|
||||||
|
|
||||||
def __init__(self, xywh, score, cls):
|
def __init__(self, xywh, score, cls):
|
||||||
"""Initialize new STrack instance."""
|
"""
|
||||||
|
Initialize a new STrack instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
|
||||||
|
(x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
|
||||||
|
score (float): Confidence score of the detection.
|
||||||
|
cls (Any): Class label for the detected object.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
|
||||||
|
>>> score = 0.9
|
||||||
|
>>> cls = 'person'
|
||||||
|
>>> track = STrack(xywh, score, cls)
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# xywh+idx or xywha+idx
|
# xywh+idx or xywha+idx
|
||||||
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
||||||
|
|
@ -60,7 +79,7 @@ class STrack(BaseTrack):
|
||||||
self.angle = xywh[4] if len(xywh) == 6 else None
|
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||||
|
|
||||||
def predict(self):
|
def predict(self):
|
||||||
"""Predicts mean and covariance using Kalman filter."""
|
"""Predicts the next state (mean and covariance) of the object using the Kalman filter."""
|
||||||
mean_state = self.mean.copy()
|
mean_state = self.mean.copy()
|
||||||
if self.state != TrackState.Tracked:
|
if self.state != TrackState.Tracked:
|
||||||
mean_state[7] = 0
|
mean_state[7] = 0
|
||||||
|
|
@ -68,7 +87,7 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def multi_predict(stracks):
|
def multi_predict(stracks):
|
||||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
|
||||||
if len(stracks) <= 0:
|
if len(stracks) <= 0:
|
||||||
return
|
return
|
||||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||||
|
|
@ -83,7 +102,7 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||||
"""Update state tracks positions and covariances using a homography matrix."""
|
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
|
||||||
if len(stracks) > 0:
|
if len(stracks) > 0:
|
||||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||||
|
|
@ -101,7 +120,7 @@ class STrack(BaseTrack):
|
||||||
stracks[i].covariance = cov
|
stracks[i].covariance = cov
|
||||||
|
|
||||||
def activate(self, kalman_filter, frame_id):
|
def activate(self, kalman_filter, frame_id):
|
||||||
"""Start a new tracklet."""
|
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
|
||||||
self.kalman_filter = kalman_filter
|
self.kalman_filter = kalman_filter
|
||||||
self.track_id = self.next_id()
|
self.track_id = self.next_id()
|
||||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||||
|
|
@ -114,7 +133,7 @@ class STrack(BaseTrack):
|
||||||
self.start_frame = frame_id
|
self.start_frame = frame_id
|
||||||
|
|
||||||
def re_activate(self, new_track, frame_id, new_id=False):
|
def re_activate(self, new_track, frame_id, new_id=False):
|
||||||
"""Reactivates a previously lost track with a new detection."""
|
"""Reactivates a previously lost track using new detection data and updates its state and attributes."""
|
||||||
self.mean, self.covariance = self.kalman_filter.update(
|
self.mean, self.covariance = self.kalman_filter.update(
|
||||||
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
||||||
)
|
)
|
||||||
|
|
@ -136,6 +155,12 @@ class STrack(BaseTrack):
|
||||||
Args:
|
Args:
|
||||||
new_track (STrack): The new track containing updated information.
|
new_track (STrack): The new track containing updated information.
|
||||||
frame_id (int): The ID of the current frame.
|
frame_id (int): The ID of the current frame.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Update the state of a track with new detection information
|
||||||
|
>>> track = STrack([100, 200, 50, 80, 0.9, 1])
|
||||||
|
>>> new_track = STrack([105, 205, 55, 85, 0.95, 1])
|
||||||
|
>>> track.update(new_track, 2)
|
||||||
"""
|
"""
|
||||||
self.frame_id = frame_id
|
self.frame_id = frame_id
|
||||||
self.tracklet_len += 1
|
self.tracklet_len += 1
|
||||||
|
|
@ -158,7 +183,7 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tlwh(self):
|
def tlwh(self):
|
||||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
"""Returns the bounding box in top-left-width-height format from the current state estimate."""
|
||||||
if self.mean is None:
|
if self.mean is None:
|
||||||
return self._tlwh.copy()
|
return self._tlwh.copy()
|
||||||
ret = self.mean[:4].copy()
|
ret = self.mean[:4].copy()
|
||||||
|
|
@ -168,16 +193,14 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def xyxy(self):
|
def xyxy(self):
|
||||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
"""Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
|
||||||
ret = self.tlwh.copy()
|
ret = self.tlwh.copy()
|
||||||
ret[2:] += ret[:2]
|
ret[2:] += ret[:2]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tlwh_to_xyah(tlwh):
|
def tlwh_to_xyah(tlwh):
|
||||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
|
||||||
height.
|
|
||||||
"""
|
|
||||||
ret = np.asarray(tlwh).copy()
|
ret = np.asarray(tlwh).copy()
|
||||||
ret[:2] += ret[2:] / 2
|
ret[:2] += ret[2:] / 2
|
||||||
ret[2] /= ret[3]
|
ret[2] /= ret[3]
|
||||||
|
|
@ -185,14 +208,14 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def xywh(self):
|
def xywh(self):
|
||||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
"""Returns the current position of the bounding box in (center x, center y, width, height) format."""
|
||||||
ret = np.asarray(self.tlwh).copy()
|
ret = np.asarray(self.tlwh).copy()
|
||||||
ret[:2] += ret[2:] / 2
|
ret[:2] += ret[2:] / 2
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def xywha(self):
|
def xywha(self):
|
||||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
"""Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
|
||||||
if self.angle is None:
|
if self.angle is None:
|
||||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||||
return self.xywh
|
return self.xywh
|
||||||
|
|
@ -200,12 +223,12 @@ class STrack(BaseTrack):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def result(self):
|
def result(self):
|
||||||
"""Get current tracking results."""
|
"""Returns the current tracking results in the appropriate bounding box format."""
|
||||||
coords = self.xyxy if self.angle is None else self.xywha
|
coords = self.xyxy if self.angle is None else self.xywha
|
||||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
"""Returns a string representation of the STrack object including start frame, end frame, and track ID."""
|
||||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -213,18 +236,18 @@ class BYTETracker:
|
||||||
"""
|
"""
|
||||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||||
|
|
||||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
|
||||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
|
||||||
predicting the new object locations, and performs data association.
|
the new object locations, and performs data association.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
tracked_stracks (List[STrack]): List of successfully activated tracks.
|
||||||
lost_stracks (list[STrack]): List of lost tracks.
|
lost_stracks (List[STrack]): List of lost tracks.
|
||||||
removed_stracks (list[STrack]): List of removed tracks.
|
removed_stracks (List[STrack]): List of removed tracks.
|
||||||
frame_id (int): The current frame ID.
|
frame_id (int): The current frame ID.
|
||||||
args (namespace): Command-line arguments.
|
args (Namespace): Command-line arguments.
|
||||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||||
kalman_filter (object): Kalman Filter object.
|
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
update(results, img=None): Updates object tracker with new detections.
|
update(results, img=None): Updates object tracker with new detections.
|
||||||
|
|
@ -236,10 +259,27 @@ class BYTETracker:
|
||||||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize BYTETracker and update with detection results
|
||||||
|
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||||
|
>>> results = yolo_model.detect(image)
|
||||||
|
>>> tracked_objects = tracker.update(results)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, args, frame_rate=30):
|
def __init__(self, args, frame_rate=30):
|
||||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
"""
|
||||||
|
Initialize a BYTETracker instance for object tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args (Namespace): Command-line arguments containing tracking parameters.
|
||||||
|
frame_rate (int): Frame rate of the video sequence.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize BYTETracker with command-line arguments and a frame rate of 30
|
||||||
|
>>> args = Namespace(track_buffer=30)
|
||||||
|
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||||
|
"""
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
self.tracked_stracks = [] # type: list[STrack]
|
||||||
self.lost_stracks = [] # type: list[STrack]
|
self.lost_stracks = [] # type: list[STrack]
|
||||||
self.removed_stracks = [] # type: list[STrack]
|
self.removed_stracks = [] # type: list[STrack]
|
||||||
|
|
@ -251,7 +291,7 @@ class BYTETracker:
|
||||||
self.reset_id()
|
self.reset_id()
|
||||||
|
|
||||||
def update(self, results, img=None):
|
def update(self, results, img=None):
|
||||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
"""Updates the tracker with new detections and returns the current list of tracked objects."""
|
||||||
self.frame_id += 1
|
self.frame_id += 1
|
||||||
activated_stracks = []
|
activated_stracks = []
|
||||||
refind_stracks = []
|
refind_stracks = []
|
||||||
|
|
@ -365,31 +405,31 @@ class BYTETracker:
|
||||||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||||
|
|
||||||
def get_kalmanfilter(self):
|
def get_kalmanfilter(self):
|
||||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
"""Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
|
||||||
return KalmanFilterXYAH()
|
return KalmanFilterXYAH()
|
||||||
|
|
||||||
def init_track(self, dets, scores, cls, img=None):
|
def init_track(self, dets, scores, cls, img=None):
|
||||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
"""Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
|
||||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||||
|
|
||||||
def get_dists(self, tracks, detections):
|
def get_dists(self, tracks, detections):
|
||||||
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
"""Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
|
||||||
dists = matching.iou_distance(tracks, detections)
|
dists = matching.iou_distance(tracks, detections)
|
||||||
if self.args.fuse_score:
|
if self.args.fuse_score:
|
||||||
dists = matching.fuse_score(dists, detections)
|
dists = matching.fuse_score(dists, detections)
|
||||||
return dists
|
return dists
|
||||||
|
|
||||||
def multi_predict(self, tracks):
|
def multi_predict(self, tracks):
|
||||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
"""Predict the next states for multiple tracks using Kalman filter."""
|
||||||
STrack.multi_predict(tracks)
|
STrack.multi_predict(tracks)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def reset_id():
|
def reset_id():
|
||||||
"""Resets the ID counter of STrack."""
|
"""Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
|
||||||
STrack.reset_id()
|
STrack.reset_id()
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset tracker."""
|
"""Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
|
||||||
self.tracked_stracks = [] # type: list[STrack]
|
self.tracked_stracks = [] # type: list[STrack]
|
||||||
self.lost_stracks = [] # type: list[STrack]
|
self.lost_stracks = [] # type: list[STrack]
|
||||||
self.removed_stracks = [] # type: list[STrack]
|
self.removed_stracks = [] # type: list[STrack]
|
||||||
|
|
@ -399,7 +439,7 @@ class BYTETracker:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def joint_stracks(tlista, tlistb):
|
def joint_stracks(tlista, tlistb):
|
||||||
"""Combine two lists of stracks into a single one."""
|
"""Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
|
||||||
exists = {}
|
exists = {}
|
||||||
res = []
|
res = []
|
||||||
for t in tlista:
|
for t in tlista:
|
||||||
|
|
@ -414,20 +454,13 @@ class BYTETracker:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sub_stracks(tlista, tlistb):
|
def sub_stracks(tlista, tlistb):
|
||||||
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
|
"""Filters out the stracks present in the second list from the first list."""
|
||||||
stracks = {t.track_id: t for t in tlista}
|
|
||||||
for t in tlistb:
|
|
||||||
tid = t.track_id
|
|
||||||
if stracks.get(tid, 0):
|
|
||||||
del stracks[tid]
|
|
||||||
return list(stracks.values())
|
|
||||||
"""
|
|
||||||
track_ids_b = {t.track_id for t in tlistb}
|
track_ids_b = {t.track_id for t in tlistb}
|
||||||
return [t for t in tlista if t.track_id not in track_ids_b]
|
return [t for t in tlista if t.track_id not in track_ids_b]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remove_duplicate_stracks(stracksa, stracksb):
|
def remove_duplicate_stracks(stracksa, stracksb):
|
||||||
"""Remove duplicate stracks with non-maximum IoU distance."""
|
"""Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
|
||||||
pdist = matching.iou_distance(stracksa, stracksb)
|
pdist = matching.iou_distance(stracksa, stracksb)
|
||||||
pairs = np.where(pdist < 0.15)
|
pairs = np.where(pdist < 0.15)
|
||||||
dupa, dupb = [], []
|
dupa, dupb = [], []
|
||||||
|
|
|
||||||
|
|
@ -21,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictor (object): The predictor object to initialize trackers for.
|
predictor (object): The predictor object to initialize trackers for.
|
||||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
persist (bool): Whether to persist the trackers if they already exist.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize trackers for a predictor object:
|
||||||
|
>>> predictor = SomePredictorClass()
|
||||||
|
>>> on_predict_start(predictor, persist=True)
|
||||||
"""
|
"""
|
||||||
if hasattr(predictor, "trackers") and persist:
|
if hasattr(predictor, "trackers") and persist:
|
||||||
return
|
return
|
||||||
|
|
@ -51,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictor (object): The predictor object containing the predictions.
|
predictor (object): The predictor object containing the predictions.
|
||||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
persist (bool): Whether to persist the trackers if they already exist.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Postprocess predictions and update with tracking
|
||||||
|
>>> predictor = YourPredictorClass()
|
||||||
|
>>> on_predict_postprocess_end(predictor, persist=True)
|
||||||
"""
|
"""
|
||||||
path, im0s = predictor.batch[:2]
|
path, im0s = predictor.batch[:2]
|
||||||
|
|
||||||
|
|
@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None:
|
||||||
Args:
|
Args:
|
||||||
model (object): The model object to register tracking callbacks for.
|
model (object): The model object to register tracking callbacks for.
|
||||||
persist (bool): Whether to persist the trackers if they already exist.
|
persist (bool): Whether to persist the trackers if they already exist.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Register tracking callbacks to a YOLO model
|
||||||
|
>>> model = YOLOModel()
|
||||||
|
>>> register_tracker(model, persist=True)
|
||||||
"""
|
"""
|
||||||
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
||||||
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
||||||
|
|
|
||||||
|
|
@ -19,27 +19,39 @@ class GMC:
|
||||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||||
downscale (int): Factor by which to downscale the frames for processing.
|
downscale (int): Factor by which to downscale the frames for processing.
|
||||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
prevKeyPoints (List): Stores the keypoints from the previous frame.
|
||||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||||
|
|
||||||
Methods:
|
Methods:
|
||||||
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
|
__init__: Initializes a GMC object with the specified method and downscale factor.
|
||||||
and downscale factor.
|
apply: Applies the chosen method to a raw frame and optionally uses provided detections.
|
||||||
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
|
applyEcc: Applies the ECC algorithm to a raw frame.
|
||||||
provided detections.
|
applyFeatures: Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||||
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
|
applySparseOptFlow: Applies the Sparse Optical Flow method to a raw frame.
|
||||||
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
|
reset_params: Resets the internal parameters of the GMC object.
|
||||||
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
|
|
||||||
|
Examples:
|
||||||
|
Create a GMC object and apply it to a frame
|
||||||
|
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
||||||
|
>>> frame = np.array([[1, 2, 3], [4, 5, 6]])
|
||||||
|
>>> processed_frame = gmc.apply(frame)
|
||||||
|
>>> print(processed_frame)
|
||||||
|
array([[1, 2, 3],
|
||||||
|
[4, 5, 6]])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize a video tracker with specified parameters.
|
Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||||
downscale (int): Downscale factor for processing frames.
|
downscale (int): Downscale factor for processing frames.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
|
||||||
|
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -79,20 +91,21 @@ class GMC:
|
||||||
|
|
||||||
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||||
"""
|
"""
|
||||||
Apply object detection on a raw frame using specified method.
|
Apply object detection on a raw frame using the specified method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_frame (np.ndarray): The raw frame to be processed.
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||||
detections (list): List of detections to be used in the processing.
|
detections (List | None): List of detections to be used in the processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Processed frame.
|
(np.ndarray): Processed frame with applied object detection.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> gmc = GMC()
|
>>> gmc = GMC(method='sparseOptFlow')
|
||||||
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
|
>>> raw_frame = np.random.rand(480, 640, 3)
|
||||||
array([[1, 2, 3],
|
>>> processed_frame = gmc.apply(raw_frame)
|
||||||
[4, 5, 6]])
|
>>> print(processed_frame.shape)
|
||||||
|
(480, 640, 3)
|
||||||
"""
|
"""
|
||||||
if self.method in {"orb", "sift"}:
|
if self.method in {"orb", "sift"}:
|
||||||
return self.applyFeatures(raw_frame, detections)
|
return self.applyFeatures(raw_frame, detections)
|
||||||
|
|
@ -105,19 +118,20 @@ class GMC:
|
||||||
|
|
||||||
def applyEcc(self, raw_frame: np.array) -> np.array:
|
def applyEcc(self, raw_frame: np.array) -> np.array:
|
||||||
"""
|
"""
|
||||||
Apply ECC algorithm to a raw frame.
|
Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_frame (np.ndarray): The raw frame to be processed.
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Processed frame.
|
(np.ndarray): The processed frame with the applied ECC transformation.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> gmc = GMC()
|
>>> gmc = GMC(method='ecc')
|
||||||
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
|
>>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
||||||
array([[1, 2, 3],
|
>>> print(processed_frame)
|
||||||
[4, 5, 6]])
|
[[1. 0. 0.]
|
||||||
|
[0. 1. 0.]]
|
||||||
"""
|
"""
|
||||||
height, width, _ = raw_frame.shape
|
height, width, _ = raw_frame.shape
|
||||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
@ -127,8 +141,6 @@ class GMC:
|
||||||
if self.downscale > 1.0:
|
if self.downscale > 1.0:
|
||||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||||
width = width // self.downscale
|
|
||||||
height = height // self.downscale
|
|
||||||
|
|
||||||
# Handle first frame
|
# Handle first frame
|
||||||
if not self.initializedFirstFrame:
|
if not self.initializedFirstFrame:
|
||||||
|
|
@ -154,17 +166,18 @@ class GMC:
|
||||||
Apply feature-based methods like ORB or SIFT to a raw frame.
|
Apply feature-based methods like ORB or SIFT to a raw frame.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_frame (np.ndarray): The raw frame to be processed.
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||||
detections (list): List of detections to be used in the processing.
|
detections (List | None): List of detections to be used in the processing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Processed frame.
|
(np.ndarray): Processed frame.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> gmc = GMC()
|
>>> gmc = GMC(method='orb')
|
||||||
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
|
>>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||||||
array([[1, 2, 3],
|
>>> processed_frame = gmc.applyFeatures(raw_frame)
|
||||||
[4, 5, 6]])
|
>>> print(processed_frame.shape)
|
||||||
|
(2, 3)
|
||||||
"""
|
"""
|
||||||
height, width, _ = raw_frame.shape
|
height, width, _ = raw_frame.shape
|
||||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
@ -296,16 +309,17 @@ class GMC:
|
||||||
Apply Sparse Optical Flow method to a raw frame.
|
Apply Sparse Optical Flow method to a raw frame.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
raw_frame (np.ndarray): The raw frame to be processed.
|
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Processed frame.
|
(np.ndarray): Processed frame with shape (2, 3).
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> gmc = GMC()
|
>>> gmc = GMC()
|
||||||
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
|
>>> result = gmc.applySparseOptFlow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
||||||
array([[1, 2, 3],
|
>>> print(result)
|
||||||
[4, 5, 6]])
|
[[1. 0. 0.]
|
||||||
|
[0. 1. 0.]]
|
||||||
"""
|
"""
|
||||||
height, width, _ = raw_frame.shape
|
height, width, _ = raw_frame.shape
|
||||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||||
|
|
@ -356,7 +370,7 @@ class GMC:
|
||||||
return H
|
return H
|
||||||
|
|
||||||
def reset_params(self) -> None:
|
def reset_params(self) -> None:
|
||||||
"""Reset parameters."""
|
"""Reset the internal parameters including previous frame, keypoints, and descriptors."""
|
||||||
self.prevFrame = None
|
self.prevFrame = None
|
||||||
self.prevKeyPoints = None
|
self.prevKeyPoints = None
|
||||||
self.prevDescriptors = None
|
self.prevDescriptors = None
|
||||||
|
|
|
||||||
|
|
@ -6,17 +6,49 @@ import scipy.linalg
|
||||||
|
|
||||||
class KalmanFilterXYAH:
|
class KalmanFilterXYAH:
|
||||||
"""
|
"""
|
||||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
|
||||||
|
|
||||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
|
Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
|
||||||
ratio a, height h, and their respective velocities.
|
(x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
|
||||||
|
respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
|
||||||
|
taken as a direct observation of the state space (linear observation model).
|
||||||
|
|
||||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
Attributes:
|
||||||
observation of the state space (linear observation model).
|
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
|
||||||
|
_update_mat (np.ndarray): The update matrix for the Kalman filter.
|
||||||
|
_std_weight_position (float): Standard deviation weight for position.
|
||||||
|
_std_weight_velocity (float): Standard deviation weight for velocity.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
initiate: Creates a track from an unassociated measurement.
|
||||||
|
predict: Runs the Kalman filter prediction step.
|
||||||
|
project: Projects the state distribution to measurement space.
|
||||||
|
multi_predict: Runs the Kalman filter prediction step (vectorized version).
|
||||||
|
update: Runs the Kalman filter correction step.
|
||||||
|
gating_distance: Computes the gating distance between state distribution and measurements.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize the Kalman filter and create a track from a measurement
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> measurement = np.array([100, 200, 1.5, 50])
|
||||||
|
>>> mean, covariance = kf.initiate(measurement)
|
||||||
|
>>> print(mean)
|
||||||
|
>>> print(covariance)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
|
"""
|
||||||
|
Initialize Kalman filter model matrices with motion and observation uncertainty weights.
|
||||||
|
|
||||||
|
The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
|
||||||
|
represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
|
||||||
|
velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
|
||||||
|
observation model for bounding box location.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Initialize a Kalman filter for tracking:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
"""
|
||||||
ndim, dt = 4, 1.0
|
ndim, dt = 4, 1.0
|
||||||
|
|
||||||
# Create Kalman filter model matrices
|
# Create Kalman filter model matrices
|
||||||
|
|
@ -32,15 +64,20 @@ class KalmanFilterXYAH:
|
||||||
|
|
||||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||||
"""
|
"""
|
||||||
Create track from unassociated measurement.
|
Create a track from an unassociated measurement.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
|
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
|
||||||
and height h.
|
and height h.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
(tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional)
|
||||||
of the new track. Unobserved velocities are initialized to 0 mean.
|
of the new track. Unobserved velocities are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> measurement = np.array([100, 50, 1.5, 200])
|
||||||
|
>>> mean, covariance = kf.initiate(measurement)
|
||||||
"""
|
"""
|
||||||
mean_pos = measurement
|
mean_pos = measurement
|
||||||
mean_vel = np.zeros_like(mean_pos)
|
mean_vel = np.zeros_like(mean_pos)
|
||||||
|
|
@ -64,12 +101,18 @@ class KalmanFilterXYAH:
|
||||||
Run Kalman filter prediction step.
|
Run Kalman filter prediction step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
|
||||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||||
velocities are initialized to 0 mean.
|
velocities are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[3],
|
self._std_weight_position * mean[3],
|
||||||
|
|
@ -100,6 +143,12 @@ class KalmanFilterXYAH:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> projected_mean, projected_covariance = kf.project(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std = [
|
std = [
|
||||||
self._std_weight_position * mean[3],
|
self._std_weight_position * mean[3],
|
||||||
|
|
@ -115,15 +164,21 @@ class KalmanFilterXYAH:
|
||||||
|
|
||||||
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||||
"""
|
"""
|
||||||
Run Kalman filter prediction step (Vectorized version).
|
Run Kalman filter prediction step for multiple object states (Vectorized version).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
(tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states.
|
||||||
velocities are initialized to 0 mean.
|
The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities
|
||||||
|
are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mean = np.random.rand(10, 8) # 10 object states
|
||||||
|
>>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states
|
||||||
|
>>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[:, 3],
|
self._std_weight_position * mean[:, 3],
|
||||||
|
|
@ -160,6 +215,13 @@ class KalmanFilterXYAH:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> measurement = np.array([1, 1, 1, 1])
|
||||||
|
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
|
||||||
"""
|
"""
|
||||||
projected_mean, projected_cov = self.project(mean, covariance)
|
projected_mean, projected_cov = self.project(mean, covariance)
|
||||||
|
|
||||||
|
|
@ -182,23 +244,31 @@ class KalmanFilterXYAH:
|
||||||
metric: str = "maha",
|
metric: str = "maha",
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
Compute gating distance between state distribution and measurements.
|
||||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
|
|
||||||
otherwise 2.
|
A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
|
||||||
|
distribution has 4 degrees of freedom, otherwise 2.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (ndarray): Mean vector over the state distribution (8 dimensional).
|
mean (ndarray): Mean vector over the state distribution (8 dimensional).
|
||||||
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
|
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
|
||||||
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
|
measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
|
||||||
is the bounding box center position, a the aspect ratio, and h the height.
|
bounding box center position, a the aspect ratio, and h the height.
|
||||||
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
|
only_position (bool): If True, distance computation is done with respect to box center position only.
|
||||||
center position only. Defaults to False.
|
metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
|
||||||
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
|
Euclidean distance and 'maha' for the squared Mahalanobis distance.
|
||||||
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
|
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
|
||||||
(mean, covariance) and `measurements[i]`.
|
(mean, covariance) and `measurements[i]`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Compute gating distance using Mahalanobis metric:
|
||||||
|
>>> kf = KalmanFilterXYAH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
|
||||||
|
>>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha')
|
||||||
"""
|
"""
|
||||||
mean, covariance = self.project(mean, covariance)
|
mean, covariance = self.project(mean, covariance)
|
||||||
if only_position:
|
if only_position:
|
||||||
|
|
@ -218,13 +288,33 @@ class KalmanFilterXYAH:
|
||||||
|
|
||||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
"""
|
"""
|
||||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
|
||||||
|
|
||||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
|
Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
|
||||||
w, height h, and their respective velocities.
|
(x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
|
||||||
|
The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
|
||||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
|
||||||
observation of the state space (linear observation model).
|
observation of the state space (linear observation model).
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
|
||||||
|
_update_mat (np.ndarray): The update matrix for the Kalman filter.
|
||||||
|
_std_weight_position (float): Standard deviation weight for position.
|
||||||
|
_std_weight_velocity (float): Standard deviation weight for velocity.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
initiate: Creates a track from an unassociated measurement.
|
||||||
|
predict: Runs the Kalman filter prediction step.
|
||||||
|
project: Projects the state distribution to measurement space.
|
||||||
|
multi_predict: Runs the Kalman filter prediction step in a vectorized manner.
|
||||||
|
update: Runs the Kalman filter correction step.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Create a Kalman filter and initialize a track
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> measurement = np.array([100, 50, 20, 40])
|
||||||
|
>>> mean, covariance = kf.initiate(measurement)
|
||||||
|
>>> print(mean)
|
||||||
|
>>> print(covariance)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||||
|
|
@ -237,6 +327,22 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
||||||
of the new track. Unobserved velocities are initialized to 0 mean.
|
of the new track. Unobserved velocities are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> measurement = np.array([100, 50, 20, 40])
|
||||||
|
>>> mean, covariance = kf.initiate(measurement)
|
||||||
|
>>> print(mean)
|
||||||
|
[100. 50. 20. 40. 0. 0. 0. 0.]
|
||||||
|
>>> print(covariance)
|
||||||
|
[[ 4. 0. 0. 0. 0. 0. 0. 0.]
|
||||||
|
[ 0. 4. 0. 0. 0. 0. 0. 0.]
|
||||||
|
[ 0. 0. 4. 0. 0. 0. 0. 0.]
|
||||||
|
[ 0. 0. 0. 4. 0. 0. 0. 0.]
|
||||||
|
[ 0. 0. 0. 0. 0.25 0. 0. 0.]
|
||||||
|
[ 0. 0. 0. 0. 0. 0.25 0. 0.]
|
||||||
|
[ 0. 0. 0. 0. 0. 0. 0.25 0.]
|
||||||
|
[ 0. 0. 0. 0. 0. 0. 0. 0.25]]
|
||||||
"""
|
"""
|
||||||
mean_pos = measurement
|
mean_pos = measurement
|
||||||
mean_vel = np.zeros_like(mean_pos)
|
mean_vel = np.zeros_like(mean_pos)
|
||||||
|
|
@ -260,12 +366,18 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
Run Kalman filter prediction step.
|
Run Kalman filter prediction step.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
|
||||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||||
velocities are initialized to 0 mean.
|
velocities are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[2],
|
self._std_weight_position * mean[2],
|
||||||
|
|
@ -296,6 +408,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> projected_mean, projected_cov = kf.project(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std = [
|
std = [
|
||||||
self._std_weight_position * mean[2],
|
self._std_weight_position * mean[2],
|
||||||
|
|
@ -320,6 +438,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||||
velocities are initialized to 0 mean.
|
velocities are initialized to 0 mean.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
|
||||||
|
>>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
|
||||||
"""
|
"""
|
||||||
std_pos = [
|
std_pos = [
|
||||||
self._std_weight_position * mean[:, 2],
|
self._std_weight_position * mean[:, 2],
|
||||||
|
|
@ -356,5 +480,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> kf = KalmanFilterXYWH()
|
||||||
|
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||||
|
>>> covariance = np.eye(8)
|
||||||
|
>>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
|
||||||
|
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
|
||||||
"""
|
"""
|
||||||
return super().update(mean, covariance, measurement)
|
return super().update(mean, covariance, measurement)
|
||||||
|
|
|
||||||
|
|
@ -19,18 +19,23 @@ except (ImportError, AssertionError, AttributeError):
|
||||||
|
|
||||||
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
|
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
|
||||||
"""
|
"""
|
||||||
Perform linear assignment using scipy or lap.lapjv.
|
Perform linear assignment using either the scipy or lap.lapjv method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
|
||||||
thresh (float): Threshold for considering an assignment valid.
|
thresh (float): Threshold for considering an assignment valid.
|
||||||
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
|
use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple with:
|
(tuple): A tuple containing:
|
||||||
- matched indices
|
- matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
|
||||||
- unmatched indices from 'a'
|
- unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
|
||||||
- unmatched indices from 'b'
|
- unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||||
|
>>> thresh = 5.0
|
||||||
|
>>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cost_matrix.size == 0:
|
if cost_matrix.size == 0:
|
||||||
|
|
@ -68,6 +73,12 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Cost matrix computed based on IoU.
|
(np.ndarray): Cost matrix computed based on IoU.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Compute IoU distance between two sets of tracks
|
||||||
|
>>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
|
||||||
|
>>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
|
||||||
|
>>> cost_matrix = iou_distance(atracks, btracks)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
|
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
|
||||||
|
|
@ -98,12 +109,19 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
|
||||||
Compute distance between tracks and detections based on embeddings.
|
Compute distance between tracks and detections based on embeddings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tracks (list[STrack]): List of tracks.
|
tracks (list[STrack]): List of tracks, where each track contains embedding features.
|
||||||
detections (list[BaseTrack]): List of detections.
|
detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
|
||||||
metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
|
metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Cost matrix computed based on embeddings.
|
(np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
|
||||||
|
and M is the number of detections.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Compute the embedding distance between tracks and detections using cosine metric
|
||||||
|
>>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
|
||||||
|
>>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
|
||||||
|
>>> cost_matrix = embedding_distance(tracks, detections, metric='cosine')
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||||
|
|
@ -122,11 +140,17 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
|
||||||
Fuses cost matrix with detection scores to produce a single similarity matrix.
|
Fuses cost matrix with detection scores to produce a single similarity matrix.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
|
||||||
detections (list[BaseTrack]): List of detections with scores.
|
detections (list[BaseTrack]): List of detections, each containing a score attribute.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(np.ndarray): Fused similarity matrix.
|
(np.ndarray): Fused similarity matrix with shape (N, M).
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Fuse a cost matrix with detection scores
|
||||||
|
>>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections
|
||||||
|
>>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
|
||||||
|
>>> fused_matrix = fuse_score(cost_matrix, detections)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cost_matrix.size == 0:
|
if cost_matrix.size == 0:
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ PYTHON_VERSION = platform.python_version()
|
||||||
TORCH_VERSION = torch.__version__
|
TORCH_VERSION = torch.__version__
|
||||||
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
||||||
HELP_MSG = """
|
HELP_MSG = """
|
||||||
Usage examples for running Ultralytics YOLO:
|
Examples for running Ultralytics:
|
||||||
|
|
||||||
1. Install the ultralytics package:
|
1. Install the ultralytics package:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,19 +11,44 @@ from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class WorkingDirectory(contextlib.ContextDecorator):
|
class WorkingDirectory(contextlib.ContextDecorator):
|
||||||
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
|
"""
|
||||||
|
A context manager and decorator for temporarily changing the working directory.
|
||||||
|
|
||||||
|
This class allows for the temporary change of the working directory using a context manager or decorator.
|
||||||
|
It ensures that the original working directory is restored after the context or decorated function completes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dir (Path): The new directory to switch to.
|
||||||
|
cwd (Path): The original current working directory before the switch.
|
||||||
|
|
||||||
|
Methods:
|
||||||
|
__enter__: Changes the current directory to the specified directory.
|
||||||
|
__exit__: Restores the original working directory on context exit.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Using as a context manager:
|
||||||
|
>>> with WorkingDirectory('/path/to/new/dir'):
|
||||||
|
>>> # Perform operations in the new directory
|
||||||
|
>>> pass
|
||||||
|
|
||||||
|
Using as a decorator:
|
||||||
|
>>> @WorkingDirectory('/path/to/new/dir')
|
||||||
|
>>> def some_function():
|
||||||
|
>>> # Perform operations in the new directory
|
||||||
|
>>> pass
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, new_dir):
|
def __init__(self, new_dir):
|
||||||
"""Sets the working directory to 'new_dir' upon instantiation."""
|
"""Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
|
||||||
self.dir = new_dir # new dir
|
self.dir = new_dir # new dir
|
||||||
self.cwd = Path.cwd().resolve() # current dir
|
self.cwd = Path.cwd().resolve() # current dir
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
"""Changes the current directory to the specified directory."""
|
"""Changes the current working directory to the specified directory upon entering the context."""
|
||||||
os.chdir(self.dir)
|
os.chdir(self.dir)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
|
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
|
||||||
"""Restore the current working directory on context exit."""
|
"""Restores the original working directory when exiting the context."""
|
||||||
os.chdir(self.cwd)
|
os.chdir(self.cwd)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -35,18 +60,16 @@ def spaces_in_path(path):
|
||||||
file/directory back to its original location.
|
file/directory back to its original location.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str | Path): The original path.
|
path (str | Path): The original path that may contain spaces.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
|
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
|
||||||
|
|
||||||
Example:
|
Examples:
|
||||||
```python
|
Use the context manager to handle paths with spaces:
|
||||||
with ultralytics.utils.files import spaces_in_path
|
>>> from ultralytics.utils.files import spaces_in_path
|
||||||
|
>>> with spaces_in_path('/path/with spaces') as new_path:
|
||||||
with spaces_in_path('/path/with spaces') as new_path:
|
>>> # Your code here
|
||||||
# Your code here
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# If path has spaces, replace them with underscores
|
# If path has spaces, replace them with underscores
|
||||||
|
|
@ -84,21 +107,35 @@ def spaces_in_path(path):
|
||||||
|
|
||||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||||
"""
|
"""
|
||||||
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||||
|
|
||||||
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
|
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
|
||||||
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
||||||
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
|
number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a
|
||||||
directory if it does not already exist.
|
directory if it does not already exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str, pathlib.Path): Path to increment.
|
path (str | pathlib.Path): Path to increment.
|
||||||
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
|
exist_ok (bool): If True, the path will not be incremented and returned as-is.
|
||||||
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
|
sep (str): Separator to use between the path and the incrementation number.
|
||||||
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
|
mkdir (bool): Create a directory if it does not exist.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(pathlib.Path): Incremented path.
|
(pathlib.Path): Incremented path.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Increment a directory path:
|
||||||
|
>>> from pathlib import Path
|
||||||
|
>>> path = Path("runs/exp")
|
||||||
|
>>> new_path = increment_path(path)
|
||||||
|
>>> print(new_path)
|
||||||
|
runs/exp2
|
||||||
|
|
||||||
|
Increment a file path:
|
||||||
|
>>> path = Path("runs/exp/results.txt")
|
||||||
|
>>> new_path = increment_path(path)
|
||||||
|
>>> print(new_path)
|
||||||
|
runs/exp/results2.txt
|
||||||
"""
|
"""
|
||||||
path = Path(path) # os-agnostic
|
path = Path(path) # os-agnostic
|
||||||
if path.exists() and not exist_ok:
|
if path.exists() and not exist_ok:
|
||||||
|
|
@ -118,19 +155,19 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||||
|
|
||||||
|
|
||||||
def file_age(path=__file__):
|
def file_age(path=__file__):
|
||||||
"""Return days since last file update."""
|
"""Return days since the last modification of the specified file."""
|
||||||
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
||||||
return dt.days # + dt.seconds / 86400 # fractional days
|
return dt.days # + dt.seconds / 86400 # fractional days
|
||||||
|
|
||||||
|
|
||||||
def file_date(path=__file__):
|
def file_date(path=__file__):
|
||||||
"""Return human-readable file modification date, i.e. '2021-3-26'."""
|
"""Returns the file modification date in 'YYYY-M-D' format."""
|
||||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||||
return f"{t.year}-{t.month}-{t.day}"
|
return f"{t.year}-{t.month}-{t.day}"
|
||||||
|
|
||||||
|
|
||||||
def file_size(path):
|
def file_size(path):
|
||||||
"""Return file/dir size (MB)."""
|
"""Returns the size of a file or directory in megabytes (MB)."""
|
||||||
if isinstance(path, (str, Path)):
|
if isinstance(path, (str, Path)):
|
||||||
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
@ -142,7 +179,7 @@ def file_size(path):
|
||||||
|
|
||||||
|
|
||||||
def get_latest_run(search_dir="."):
|
def get_latest_run(search_dir="."):
|
||||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
"""Returns the path to the most recent 'last.pt' file in the specified directory for resuming training."""
|
||||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||||
return max(last_list, key=os.path.getctime) if last_list else ""
|
return max(last_list, key=os.path.getctime) if last_list else ""
|
||||||
|
|
||||||
|
|
@ -152,17 +189,15 @@ def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_name
|
||||||
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
|
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
|
model_names (Tuple[str, ...]): Model filenames to update.
|
||||||
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
|
source_dir (Path): Directory containing models and target subdirectory.
|
||||||
update_names (bool, optional): Update model names from a data YAML.
|
update_names (bool): Update model names from a data YAML.
|
||||||
|
|
||||||
Example:
|
Examples:
|
||||||
```python
|
Update specified YOLO models and save them in 'updated_models' subdirectory:
|
||||||
from ultralytics.utils.files import update_models
|
>>> from ultralytics.utils.files import update_models
|
||||||
|
>>> model_names = ("yolov8n.pt", "yolov8s.pt")
|
||||||
model_names = (f"rtdetr-{size}.pt" for size in "lx")
|
>>> update_models(model_names, source_dir=Path("/models"), update_names=True)
|
||||||
update_models(model_names)
|
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
from ultralytics import YOLO
|
from ultralytics import YOLO
|
||||||
from ultralytics.nn.autobackend import default_class_names
|
from ultralytics.nn.autobackend import default_class_names
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue