Update Tracker docstrings (#15469)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
da2797a182
commit
b7c5db94b4
10 changed files with 501 additions and 196 deletions
|
|
@ -81,7 +81,7 @@ CLI_HELP_MSG = f"""
|
|||
5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
|
||||
yolo explorer data=data.yaml model=yolov8n.pt
|
||||
|
||||
6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8
|
||||
6. Streamlit real-time webcam inference GUI
|
||||
yolo streamlit-predict
|
||||
|
||||
7. Run special commands:
|
||||
|
|
|
|||
|
|
@ -15,6 +15,11 @@ class TrackState:
|
|||
Tracked (int): State when the object is successfully tracked in subsequent frames.
|
||||
Lost (int): State when the object is no longer tracked.
|
||||
Removed (int): State when the object is removed from tracking.
|
||||
|
||||
Examples:
|
||||
>>> state = TrackState.New
|
||||
>>> if state == TrackState.New:
|
||||
>>> print("Object is newly detected.")
|
||||
"""
|
||||
|
||||
New = 0
|
||||
|
|
@ -33,13 +38,13 @@ class BaseTrack:
|
|||
is_activated (bool): Flag indicating whether the track is currently active.
|
||||
state (TrackState): Current state of the track.
|
||||
history (OrderedDict): Ordered history of the track's states.
|
||||
features (list): List of features extracted from the object for tracking.
|
||||
curr_feature (any): The current feature of the object being tracked.
|
||||
features (List): List of features extracted from the object for tracking.
|
||||
curr_feature (Any): The current feature of the object being tracked.
|
||||
score (float): The confidence score of the tracking.
|
||||
start_frame (int): The frame number where tracking started.
|
||||
frame_id (int): The most recent frame ID processed by the track.
|
||||
time_since_update (int): Frames passed since the last update.
|
||||
location (tuple): The location of the object in the context of multi-camera tracking.
|
||||
location (Tuple): The location of the object in the context of multi-camera tracking.
|
||||
|
||||
Methods:
|
||||
end_frame: Returns the ID of the last frame where the object was tracked.
|
||||
|
|
@ -50,12 +55,26 @@ class BaseTrack:
|
|||
mark_lost: Marks the track as lost.
|
||||
mark_removed: Marks the track as removed.
|
||||
reset_id: Resets the global track ID counter.
|
||||
|
||||
Examples:
|
||||
Initialize a new track and mark it as lost:
|
||||
>>> track = BaseTrack()
|
||||
>>> track.mark_lost()
|
||||
>>> print(track.state) # Output: 2 (TrackState.Lost)
|
||||
"""
|
||||
|
||||
_count = 0
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes a new track with unique ID and foundational tracking attributes."""
|
||||
"""
|
||||
Initializes a new track with a unique ID and foundational tracking attributes.
|
||||
|
||||
Examples:
|
||||
Initialize a new track
|
||||
>>> track = BaseTrack()
|
||||
>>> print(track.track_id)
|
||||
0
|
||||
"""
|
||||
self.track_id = 0
|
||||
self.is_activated = False
|
||||
self.state = TrackState.New
|
||||
|
|
@ -70,36 +89,36 @@ class BaseTrack:
|
|||
|
||||
@property
|
||||
def end_frame(self):
|
||||
"""Return the last frame ID of the track."""
|
||||
"""Returns the ID of the most recent frame where the object was tracked."""
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
"""Increment and return the global track ID counter."""
|
||||
"""Increment and return the next unique global track ID for object tracking."""
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""Abstract method to activate the track with provided arguments."""
|
||||
"""Activates the track with provided arguments, initializing necessary attributes for tracking."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""Abstract method to predict the next state of the track."""
|
||||
"""Predicts the next state of the track based on the current state and tracking model."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Abstract method to update the track with new observations."""
|
||||
"""Updates the track with new observations and data, modifying its state and attributes accordingly."""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
"""Mark the track as lost."""
|
||||
"""Marks the track as lost by updating its state to TrackState.Lost."""
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
"""Mark the track as removed."""
|
||||
"""Marks the track as removed by setting its state to TrackState.Removed."""
|
||||
self.state = TrackState.Removed
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Reset the global track ID counter."""
|
||||
"""Reset the global track ID counter to its initial value."""
|
||||
BaseTrack._count = 0
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ class BOTrack(STrack):
|
|||
"""
|
||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||
|
||||
This class extends the STrack class to include additional functionalities for object tracking, such as feature
|
||||
smoothing, Kalman filter prediction, and reactivation of tracks.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||
|
|
@ -34,16 +37,35 @@ class BOTrack(STrack):
|
|||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||
|
||||
Usage:
|
||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
bo_track.predict()
|
||||
bo_track.update(new_track, frame_id)
|
||||
Examples:
|
||||
Create a BOTrack instance and update its features
|
||||
>>> bo_track = BOTrack(tlwh=[100, 50, 80, 40, 0], score=0.9, cls=1, feat=np.random.rand(128))
|
||||
>>> bo_track.predict()
|
||||
>>> new_track = BOTrack(tlwh=[110, 60, 80, 40, 0], score=0.85, cls=1, feat=np.random.rand(128))
|
||||
>>> bo_track.update(new_track, frame_id=2)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
||||
"""
|
||||
Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
|
||||
|
||||
Args:
|
||||
tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height).
|
||||
score (float): Confidence score of the detection.
|
||||
cls (int): Class ID of the detected object.
|
||||
feat (np.ndarray | None): Feature vector associated with the detection.
|
||||
feat_history (int): Maximum length of the feature history deque.
|
||||
|
||||
Examples:
|
||||
Initialize a BOTrack object with bounding box, score, class ID, and feature vector
|
||||
>>> tlwh = np.array([100, 50, 80, 120, 0])
|
||||
>>> score = 0.9
|
||||
>>> cls = 1
|
||||
>>> feat = np.random.rand(128)
|
||||
>>> bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
"""
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
|
|
@ -54,7 +76,7 @@ class BOTrack(STrack):
|
|||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
"""Update features vector and smooth it using exponential moving average."""
|
||||
"""Update the feature vector and apply exponential moving average smoothing."""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
|
|
@ -65,7 +87,7 @@ class BOTrack(STrack):
|
|||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""Predicts the mean and covariance using Kalman filter."""
|
||||
"""Predicts the object's future state using the Kalman filter to update its mean and covariance."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
|
|
@ -80,14 +102,14 @@ class BOTrack(STrack):
|
|||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
||||
"""Updates the BOTrack instance with new track information and the current frame ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
||||
"""Returns the current bounding box position in `(top left x, top left y, width, height)` format."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -96,7 +118,7 @@ class BOTrack(STrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
||||
"""Predicts the mean and covariance for multiple object tracks using a shared Kalman filter."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
|
|
@ -111,12 +133,12 @@ class BOTrack(STrack):
|
|||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
||||
"""Converts tlwh bounding box coordinates to xywh format."""
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
||||
"""Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
|
@ -129,9 +151,9 @@ class BOTSORT(BYTETracker):
|
|||
Attributes:
|
||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
args (Any): Parsed command-line arguments containing tracking parameters.
|
||||
|
||||
Methods:
|
||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||
|
|
@ -139,17 +161,29 @@ class BOTSORT(BYTETracker):
|
|||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||
|
||||
Usage:
|
||||
bot_sort = BOTSORT(args, frame_rate)
|
||||
bot_sort.init_track(dets, scores, cls, img)
|
||||
bot_sort.multi_predict(tracks)
|
||||
Examples:
|
||||
Initialize BOTSORT and process detections
|
||||
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||
>>> bot_sort.init_track(dets, scores, cls, img)
|
||||
>>> bot_sort.multi_predict(tracks)
|
||||
|
||||
Note:
|
||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
"""
|
||||
Initialize a BOTSORT tracker with ReID module and GMC algorithm.
|
||||
|
||||
Args:
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
frame_rate (int): Frame rate of the video being processed.
|
||||
|
||||
Examples:
|
||||
Initialize BOTSORT with command-line arguments and a specified frame rate:
|
||||
>>> args = parse_args()
|
||||
>>> bot_sort = BOTSORT(args, frame_rate=30)
|
||||
"""
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
|
|
@ -161,11 +195,11 @@ class BOTSORT(BYTETracker):
|
|||
self.gmc = GMC(method=args.gmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
"""Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
"""Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
|
|
@ -175,7 +209,7 @@ class BOTSORT(BYTETracker):
|
|||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
"""Calculates distances between tracks and detections using IoU and optionally ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = dists > self.proximity_thresh
|
||||
|
||||
|
|
@ -190,10 +224,10 @@ class BOTSORT(BYTETracker):
|
|||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
"""Predicts the mean and covariance of multiple object tracks using a shared Kalman filter."""
|
||||
BOTrack.multi_predict(tracks)
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
"""Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
|
||||
super().reset()
|
||||
self.gmc.reset_params()
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class STrack(BaseTrack):
|
|||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||
score (float): Confidence score of the track.
|
||||
tracklet_len (int): Length of the tracklet.
|
||||
cls (any): Class label for the object.
|
||||
cls (Any): Class label for the object.
|
||||
idx (int): Index or identifier for the object.
|
||||
frame_id (int): Current frame ID.
|
||||
start_frame (int): Frame where the object was first detected.
|
||||
|
|
@ -39,12 +39,31 @@ class STrack(BaseTrack):
|
|||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to x-y-aspect-height format.
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
|
||||
Examples:
|
||||
Initialize and activate a new track
|
||||
>>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person')
|
||||
>>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, xywh, score, cls):
|
||||
"""Initialize new STrack instance."""
|
||||
"""
|
||||
Initialize a new STrack instance.
|
||||
|
||||
Args:
|
||||
xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
|
||||
(x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
|
||||
score (float): Confidence score of the detection.
|
||||
cls (Any): Class label for the detected object.
|
||||
|
||||
Examples:
|
||||
>>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
|
||||
>>> score = 0.9
|
||||
>>> cls = 'person'
|
||||
>>> track = STrack(xywh, score, cls)
|
||||
"""
|
||||
super().__init__()
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
||||
|
|
@ -60,7 +79,7 @@ class STrack(BaseTrack):
|
|||
self.angle = xywh[4] if len(xywh) == 6 else None
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
"""Predicts the next state (mean and covariance) of the object using the Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
|
|
@ -68,7 +87,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
||||
"""Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
|
|
@ -83,7 +102,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
"""Update state tracks positions and covariances using a homography matrix."""
|
||||
"""Update state tracks positions and covariances using a homography matrix for multiple tracks."""
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
|
|
@ -101,7 +120,7 @@ class STrack(BaseTrack):
|
|||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet."""
|
||||
"""Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||
|
|
@ -114,7 +133,7 @@ class STrack(BaseTrack):
|
|||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
"""Reactivates a previously lost track using new detection data and updates its state and attributes."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(
|
||||
self.mean, self.covariance, self.convert_coords(new_track.tlwh)
|
||||
)
|
||||
|
|
@ -136,6 +155,12 @@ class STrack(BaseTrack):
|
|||
Args:
|
||||
new_track (STrack): The new track containing updated information.
|
||||
frame_id (int): The ID of the current frame.
|
||||
|
||||
Examples:
|
||||
Update the state of a track with new detection information
|
||||
>>> track = STrack([100, 200, 50, 80, 0], score=0.9, cls='person')
|
||||
>>> new_track = STrack([105, 205, 55, 85, 0], score=0.95, cls='person')
|
||||
>>> track.update(new_track, 2)
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
|
@ -158,7 +183,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
||||
"""Returns the bounding box in top-left-width-height format from the current state estimate."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -168,16 +193,14 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def xyxy(self):
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
"""Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
||||
height.
|
||||
"""
|
||||
"""Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
|
|
@ -185,14 +208,14 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def xywh(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height)."""
|
||||
"""Returns the current position of the bounding box in (center x, center y, width, height) format."""
|
||||
ret = np.asarray(self.tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def xywha(self):
|
||||
"""Get current position in bounding box format (center x, center y, width, height, angle)."""
|
||||
"""Returns position in (center x, center y, width, height, angle) format, warning if angle is missing."""
|
||||
if self.angle is None:
|
||||
LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.")
|
||||
return self.xywh
|
||||
|
|
@ -200,12 +223,12 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def result(self):
|
||||
"""Get current tracking results."""
|
||||
"""Returns the current tracking results in the appropriate bounding box format."""
|
||||
coords = self.xyxy if self.angle is None else self.xywha
|
||||
return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
"""Returns a string representation of the STrack object including start frame, end frame, and track ID."""
|
||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||
|
||||
|
||||
|
|
@ -213,18 +236,18 @@ class BYTETracker:
|
|||
"""
|
||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||
|
||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
||||
predicting the new object locations, and performs data association.
|
||||
Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence.
|
||||
It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting
|
||||
the new object locations, and performs data association.
|
||||
|
||||
Attributes:
|
||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (list[STrack]): List of lost tracks.
|
||||
removed_stracks (list[STrack]): List of removed tracks.
|
||||
tracked_stracks (List[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (List[STrack]): List of lost tracks.
|
||||
removed_stracks (List[STrack]): List of removed tracks.
|
||||
frame_id (int): The current frame ID.
|
||||
args (namespace): Command-line arguments.
|
||||
args (Namespace): Command-line arguments.
|
||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||
kalman_filter (object): Kalman Filter object.
|
||||
kalman_filter (KalmanFilterXYAH): Kalman Filter object.
|
||||
|
||||
Methods:
|
||||
update(results, img=None): Updates object tracker with new detections.
|
||||
|
|
@ -236,10 +259,27 @@ class BYTETracker:
|
|||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU.
|
||||
|
||||
Examples:
|
||||
Initialize BYTETracker and update with detection results
|
||||
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||
>>> results = yolo_model.detect(image)
|
||||
>>> tracked_objects = tracker.update(results)
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
"""
|
||||
Initialize a BYTETracker instance for object tracking.
|
||||
|
||||
Args:
|
||||
args (Namespace): Command-line arguments containing tracking parameters.
|
||||
frame_rate (int): Frame rate of the video sequence.
|
||||
|
||||
Examples:
|
||||
Initialize BYTETracker with command-line arguments and a frame rate of 30
|
||||
>>> args = Namespace(track_buffer=30)
|
||||
>>> tracker = BYTETracker(args, frame_rate=30)
|
||||
"""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
|
@ -251,7 +291,7 @@ class BYTETracker:
|
|||
self.reset_id()
|
||||
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
"""Updates the tracker with new detections and returns the current list of tracked objects."""
|
||||
self.frame_id += 1
|
||||
activated_stracks = []
|
||||
refind_stracks = []
|
||||
|
|
@ -365,31 +405,31 @@ class BYTETracker:
|
|||
return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
"""Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
"""Initializes object tracking with given detections, scores, and class labels using the STrack algorithm."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IoU and fuses scores."""
|
||||
"""Calculates the distance between tracks and detections using IoU and optionally fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
if self.args.fuse_score:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
"""Predict the next states for multiple tracks using Kalman filter."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Resets the ID counter of STrack."""
|
||||
"""Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
|
||||
STrack.reset_id()
|
||||
|
||||
def reset(self):
|
||||
"""Reset tracker."""
|
||||
"""Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
|
@ -399,7 +439,7 @@ class BYTETracker:
|
|||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
"""Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
|
|
@ -414,20 +454,13 @@ class BYTETracker:
|
|||
|
||||
@staticmethod
|
||||
def sub_stracks(tlista, tlistb):
|
||||
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
|
||||
stracks = {t.track_id: t for t in tlista}
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
"""
|
||||
"""Filters out the stracks present in the second list from the first list."""
|
||||
track_ids_b = {t.track_id for t in tlistb}
|
||||
return [t for t in tlista if t.track_id not in track_ids_b]
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IoU distance."""
|
||||
"""Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
|
|
|
|||
|
|
@ -21,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|||
|
||||
Args:
|
||||
predictor (object): The predictor object to initialize trackers for.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
|
||||
Examples:
|
||||
Initialize trackers for a predictor object:
|
||||
>>> predictor = SomePredictorClass()
|
||||
>>> on_predict_start(predictor, persist=True)
|
||||
"""
|
||||
if hasattr(predictor, "trackers") and persist:
|
||||
return
|
||||
|
|
@ -51,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None
|
|||
|
||||
Args:
|
||||
predictor (object): The predictor object containing the predictions.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
Examples:
|
||||
Postprocess predictions and update with tracking
|
||||
>>> predictor = YourPredictorClass()
|
||||
>>> on_predict_postprocess_end(predictor, persist=True)
|
||||
"""
|
||||
path, im0s = predictor.batch[:2]
|
||||
|
||||
|
|
@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None:
|
|||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
Examples:
|
||||
Register tracking callbacks to a YOLO model
|
||||
>>> model = YOLOModel()
|
||||
>>> register_tracker(model, persist=True)
|
||||
"""
|
||||
model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
|
||||
model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
|
||||
|
|
|
|||
|
|
@ -19,27 +19,39 @@ class GMC:
|
|||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Factor by which to downscale the frames for processing.
|
||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
||||
prevKeyPoints (List): Stores the keypoints from the previous frame.
|
||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||
|
||||
Methods:
|
||||
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
|
||||
and downscale factor.
|
||||
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
|
||||
provided detections.
|
||||
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
|
||||
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
|
||||
__init__: Initializes a GMC object with the specified method and downscale factor.
|
||||
apply: Applies the chosen method to a raw frame and optionally uses provided detections.
|
||||
applyEcc: Applies the ECC algorithm to a raw frame.
|
||||
applyFeatures: Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||
applySparseOptFlow: Applies the Sparse Optical Flow method to a raw frame.
|
||||
reset_params: Resets the internal parameters of the GMC object.
|
||||
|
||||
Examples:
|
||||
Create a GMC object and apply it to a frame
|
||||
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
||||
>>> frame = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
>>> processed_frame = gmc.apply(frame)
|
||||
>>> print(processed_frame)
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
|
||||
def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
|
||||
"""
|
||||
Initialize a video tracker with specified parameters.
|
||||
Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
|
||||
|
||||
Args:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Downscale factor for processing frames.
|
||||
|
||||
Examples:
|
||||
Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
|
||||
>>> gmc = GMC(method='sparseOptFlow', downscale=2)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -79,20 +91,21 @@ class GMC:
|
|||
|
||||
def apply(self, raw_frame: np.array, detections: list = None) -> np.array:
|
||||
"""
|
||||
Apply object detection on a raw frame using specified method.
|
||||
Apply object detection on a raw frame using the specified method.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||
detections (List | None): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
(np.ndarray): Processed frame with applied object detection.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
>>> gmc = GMC(method='sparseOptFlow')
|
||||
>>> raw_frame = np.random.rand(480, 640, 3)
|
||||
>>> processed_frame = gmc.apply(raw_frame)
|
||||
>>> print(processed_frame.shape)
|
||||
(2, 3)
|
||||
"""
|
||||
if self.method in {"orb", "sift"}:
|
||||
return self.applyFeatures(raw_frame, detections)
|
||||
|
|
@ -105,19 +118,20 @@ class GMC:
|
|||
|
||||
def applyEcc(self, raw_frame: np.array) -> np.array:
|
||||
"""
|
||||
Apply ECC algorithm to a raw frame.
|
||||
Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
(np.ndarray): The processed frame with the applied ECC transformation.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
>>> gmc = GMC(method='ecc')
|
||||
>>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
||||
>>> print(processed_frame)
|
||||
[[1. 0. 0.]
|
||||
[0. 1. 0.]]
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
|
|
@ -127,8 +141,6 @@ class GMC:
|
|||
if self.downscale > 1.0:
|
||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
|
|
@ -154,17 +166,18 @@ class GMC:
|
|||
Apply feature-based methods like ORB or SIFT to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
detections (list): List of detections to be used in the processing.
|
||||
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||
detections (List | None): List of detections to be used in the processing.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
>>> gmc = GMC(method='orb')
|
||||
>>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
|
||||
>>> processed_frame = gmc.applyFeatures(raw_frame)
|
||||
>>> print(processed_frame.shape)
|
||||
(2, 3)
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
|
|
@ -296,16 +309,17 @@ class GMC:
|
|||
Apply Sparse Optical Flow method to a raw frame.
|
||||
|
||||
Args:
|
||||
raw_frame (np.ndarray): The raw frame to be processed.
|
||||
raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Processed frame.
|
||||
(np.ndarray): Processed frame with shape (2, 3).
|
||||
|
||||
Examples:
|
||||
>>> gmc = GMC()
|
||||
>>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]]))
|
||||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
>>> result = gmc.applySparseOptFlow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]))
|
||||
>>> print(result)
|
||||
[[1. 0. 0.]
|
||||
[0. 1. 0.]]
|
||||
"""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
|
|
@ -356,7 +370,7 @@ class GMC:
|
|||
return H
|
||||
|
||||
def reset_params(self) -> None:
|
||||
"""Reset parameters."""
|
||||
"""Reset the internal parameters including previous frame, keypoints, and descriptors."""
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
|
|
|
|||
|
|
@ -6,17 +6,49 @@ import scipy.linalg
|
|||
|
||||
class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
|
||||
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
|
||||
ratio a, height h, and their respective velocities.
|
||||
Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
|
||||
(x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
|
||||
respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
|
||||
taken as a direct observation of the state space (linear observation model).
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
Attributes:
|
||||
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
|
||||
_update_mat (np.ndarray): The update matrix for the Kalman filter.
|
||||
_std_weight_position (float): Standard deviation weight for position.
|
||||
_std_weight_velocity (float): Standard deviation weight for velocity.
|
||||
|
||||
Methods:
|
||||
initiate: Creates a track from an unassociated measurement.
|
||||
predict: Runs the Kalman filter prediction step.
|
||||
project: Projects the state distribution to measurement space.
|
||||
multi_predict: Runs the Kalman filter prediction step (vectorized version).
|
||||
update: Runs the Kalman filter correction step.
|
||||
gating_distance: Computes the gating distance between state distribution and measurements.
|
||||
|
||||
Examples:
|
||||
Initialize the Kalman filter and create a track from a measurement
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> measurement = np.array([100, 200, 1.5, 50])
|
||||
>>> mean, covariance = kf.initiate(measurement)
|
||||
>>> print(mean)
|
||||
>>> print(covariance)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
|
||||
"""
|
||||
Initialize Kalman filter model matrices with motion and observation uncertainty weights.
|
||||
|
||||
The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
|
||||
represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
|
||||
velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
|
||||
observation model for bounding box location.
|
||||
|
||||
Examples:
|
||||
Initialize a Kalman filter for tracking:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
"""
|
||||
ndim, dt = 4, 1.0
|
||||
|
||||
# Create Kalman filter model matrices
|
||||
|
|
@ -32,15 +64,20 @@ class KalmanFilterXYAH:
|
|||
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
Create a track from an unassociated measurement.
|
||||
|
||||
Args:
|
||||
measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
|
||||
and height h.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional)
|
||||
of the new track. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> measurement = np.array([100, 50, 1.5, 200])
|
||||
>>> mean, covariance = kf.initiate(measurement)
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
|
|
@ -64,12 +101,18 @@ class KalmanFilterXYAH:
|
|||
Run Kalman filter prediction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3],
|
||||
|
|
@ -100,6 +143,12 @@ class KalmanFilterXYAH:
|
|||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> projected_mean, projected_covariance = kf.project(mean, covariance)
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3],
|
||||
|
|
@ -115,15 +164,21 @@ class KalmanFilterXYAH:
|
|||
|
||||
def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple:
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
Run Kalman filter prediction step for multiple object states (Vectorized version).
|
||||
|
||||
Args:
|
||||
mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
(tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states.
|
||||
The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities
|
||||
are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> mean = np.random.rand(10, 8) # 10 object states
|
||||
>>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states
|
||||
>>> predicted_mean, predicted_covariance = KalmanFilterXYAH().multi_predict(mean, covariance)
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3],
|
||||
|
|
@ -160,6 +215,13 @@ class KalmanFilterXYAH:
|
|||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> measurement = np.array([1, 1, 1, 1])
|
||||
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
|
|
@ -182,23 +244,31 @@ class KalmanFilterXYAH:
|
|||
metric: str = "maha",
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
|
||||
otherwise 2.
|
||||
Compute gating distance between state distribution and measurements.
|
||||
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
|
||||
distribution has 4 degrees of freedom, otherwise 2.
|
||||
|
||||
Args:
|
||||
mean (ndarray): Mean vector over the state distribution (8 dimensional).
|
||||
covariance (ndarray): Covariance of the state distribution (8x8 dimensional).
|
||||
measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y)
|
||||
is the bounding box center position, a the aspect ratio, and h the height.
|
||||
only_position (bool, optional): If True, distance computation is done with respect to the bounding box
|
||||
center position only. Defaults to False.
|
||||
metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the
|
||||
squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'.
|
||||
measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
|
||||
bounding box center position, a the aspect ratio, and h the height.
|
||||
only_position (bool): If True, distance computation is done with respect to box center position only.
|
||||
metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared
|
||||
Euclidean distance and 'maha' for the squared Mahalanobis distance.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
|
||||
(mean, covariance) and `measurements[i]`.
|
||||
|
||||
Examples:
|
||||
Compute gating distance using Mahalanobis metric:
|
||||
>>> kf = KalmanFilterXYAH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
|
||||
>>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha')
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
|
|
@ -218,13 +288,33 @@ class KalmanFilterXYAH:
|
|||
|
||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
"""
|
||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
|
||||
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
|
||||
w, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||
Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
|
||||
(x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
|
||||
The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
|
||||
observation of the state space (linear observation model).
|
||||
|
||||
Attributes:
|
||||
_motion_mat (np.ndarray): The motion matrix for the Kalman filter.
|
||||
_update_mat (np.ndarray): The update matrix for the Kalman filter.
|
||||
_std_weight_position (float): Standard deviation weight for position.
|
||||
_std_weight_velocity (float): Standard deviation weight for velocity.
|
||||
|
||||
Methods:
|
||||
initiate: Creates a track from an unassociated measurement.
|
||||
predict: Runs the Kalman filter prediction step.
|
||||
project: Projects the state distribution to measurement space.
|
||||
multi_predict: Runs the Kalman filter prediction step in a vectorized manner.
|
||||
update: Runs the Kalman filter correction step.
|
||||
|
||||
Examples:
|
||||
Create a Kalman filter and initialize a track
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> measurement = np.array([100, 50, 20, 40])
|
||||
>>> mean, covariance = kf.initiate(measurement)
|
||||
>>> print(mean)
|
||||
>>> print(covariance)
|
||||
"""
|
||||
|
||||
def initiate(self, measurement: np.ndarray) -> tuple:
|
||||
|
|
@ -237,6 +327,22 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional)
|
||||
of the new track. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> measurement = np.array([100, 50, 20, 40])
|
||||
>>> mean, covariance = kf.initiate(measurement)
|
||||
>>> print(mean)
|
||||
[100. 50. 20. 40. 0. 0. 0. 0.]
|
||||
>>> print(covariance)
|
||||
[[ 4. 0. 0. 0. 0. 0. 0. 0.]
|
||||
[ 0. 4. 0. 0. 0. 0. 0. 0.]
|
||||
[ 0. 0. 4. 0. 0. 0. 0. 0.]
|
||||
[ 0. 0. 0. 4. 0. 0. 0. 0.]
|
||||
[ 0. 0. 0. 0. 0.25 0. 0. 0.]
|
||||
[ 0. 0. 0. 0. 0. 0.25 0. 0.]
|
||||
[ 0. 0. 0. 0. 0. 0. 0.25 0.]
|
||||
[ 0. 0. 0. 0. 0. 0. 0. 0.25]]
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
|
|
@ -260,12 +366,18 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
Run Kalman filter prediction step.
|
||||
|
||||
Args:
|
||||
mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step.
|
||||
covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2],
|
||||
|
|
@ -296,6 +408,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> projected_mean, projected_cov = kf.project(mean, covariance)
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2],
|
||||
|
|
@ -320,6 +438,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved
|
||||
velocities are initialized to 0 mean.
|
||||
|
||||
Examples:
|
||||
>>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
|
||||
>>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2],
|
||||
|
|
@ -356,5 +480,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
|
||||
Returns:
|
||||
(tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution.
|
||||
|
||||
Examples:
|
||||
>>> kf = KalmanFilterXYWH()
|
||||
>>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
|
||||
>>> covariance = np.eye(8)
|
||||
>>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
|
||||
>>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
|
||||
"""
|
||||
return super().update(mean, covariance, measurement)
|
||||
|
|
|
|||
|
|
@ -19,18 +19,23 @@ except (ImportError, AssertionError, AttributeError):
|
|||
|
||||
def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple:
|
||||
"""
|
||||
Perform linear assignment using scipy or lap.lapjv.
|
||||
Perform linear assignment using either the scipy or lap.lapjv method.
|
||||
|
||||
Args:
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
|
||||
thresh (float): Threshold for considering an assignment valid.
|
||||
use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True.
|
||||
use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
|
||||
|
||||
Returns:
|
||||
Tuple with:
|
||||
- matched indices
|
||||
- unmatched indices from 'a'
|
||||
- unmatched indices from 'b'
|
||||
(tuple): A tuple containing:
|
||||
- matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
|
||||
- unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
|
||||
- unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (P,).
|
||||
|
||||
Examples:
|
||||
>>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
>>> thresh = 5.0
|
||||
>>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
|
||||
"""
|
||||
|
||||
if cost_matrix.size == 0:
|
||||
|
|
@ -68,6 +73,12 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
|
|||
|
||||
Returns:
|
||||
(np.ndarray): Cost matrix computed based on IoU.
|
||||
|
||||
Examples:
|
||||
Compute IoU distance between two sets of tracks
|
||||
>>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
|
||||
>>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
|
||||
>>> cost_matrix = iou_distance(atracks, btracks)
|
||||
"""
|
||||
|
||||
if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
|
||||
|
|
@ -98,12 +109,19 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -
|
|||
Compute distance between tracks and detections based on embeddings.
|
||||
|
||||
Args:
|
||||
tracks (list[STrack]): List of tracks.
|
||||
detections (list[BaseTrack]): List of detections.
|
||||
metric (str, optional): Metric for distance computation. Defaults to 'cosine'.
|
||||
tracks (list[STrack]): List of tracks, where each track contains embedding features.
|
||||
detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
|
||||
metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Cost matrix computed based on embeddings.
|
||||
(np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
|
||||
and M is the number of detections.
|
||||
|
||||
Examples:
|
||||
Compute the embedding distance between tracks and detections using cosine metric
|
||||
>>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
|
||||
>>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
|
||||
>>> cost_matrix = embedding_distance(tracks, detections, metric='cosine')
|
||||
"""
|
||||
|
||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||
|
|
@ -122,11 +140,17 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
|
|||
Fuses cost matrix with detection scores to produce a single similarity matrix.
|
||||
|
||||
Args:
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments.
|
||||
detections (list[BaseTrack]): List of detections with scores.
|
||||
cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
|
||||
detections (list[BaseTrack]): List of detections, each containing a score attribute.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Fused similarity matrix.
|
||||
(np.ndarray): Fused similarity matrix with shape (N, M).
|
||||
|
||||
Examples:
|
||||
Fuse a cost matrix with detection scores
|
||||
>>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections
|
||||
>>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
|
||||
>>> fused_matrix = fuse_score(cost_matrix, detections)
|
||||
"""
|
||||
|
||||
if cost_matrix.size == 0:
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ PYTHON_VERSION = platform.python_version()
|
|||
TORCH_VERSION = torch.__version__
|
||||
TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
|
||||
HELP_MSG = """
|
||||
Usage examples for running Ultralytics YOLO:
|
||||
Examples for running Ultralytics:
|
||||
|
||||
1. Install the ultralytics package:
|
||||
|
||||
|
|
|
|||
|
|
@ -11,19 +11,44 @@ from pathlib import Path
|
|||
|
||||
|
||||
class WorkingDirectory(contextlib.ContextDecorator):
|
||||
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
|
||||
"""
|
||||
A context manager and decorator for temporarily changing the working directory.
|
||||
|
||||
This class allows for the temporary change of the working directory using a context manager or decorator.
|
||||
It ensures that the original working directory is restored after the context or decorated function completes.
|
||||
|
||||
Attributes:
|
||||
dir (Path): The new directory to switch to.
|
||||
cwd (Path): The original current working directory before the switch.
|
||||
|
||||
Methods:
|
||||
__enter__: Changes the current directory to the specified directory.
|
||||
__exit__: Restores the original working directory on context exit.
|
||||
|
||||
Examples:
|
||||
Using as a context manager:
|
||||
>>> with WorkingDirectory('/path/to/new/dir'):
|
||||
>>> # Perform operations in the new directory
|
||||
>>> pass
|
||||
|
||||
Using as a decorator:
|
||||
>>> @WorkingDirectory('/path/to/new/dir')
|
||||
>>> def some_function():
|
||||
>>> # Perform operations in the new directory
|
||||
>>> pass
|
||||
"""
|
||||
|
||||
def __init__(self, new_dir):
|
||||
"""Sets the working directory to 'new_dir' upon instantiation."""
|
||||
"""Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators."""
|
||||
self.dir = new_dir # new dir
|
||||
self.cwd = Path.cwd().resolve() # current dir
|
||||
|
||||
def __enter__(self):
|
||||
"""Changes the current directory to the specified directory."""
|
||||
"""Changes the current working directory to the specified directory upon entering the context."""
|
||||
os.chdir(self.dir)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
|
||||
"""Restore the current working directory on context exit."""
|
||||
"""Restores the original working directory when exiting the context."""
|
||||
os.chdir(self.cwd)
|
||||
|
||||
|
||||
|
|
@ -35,18 +60,16 @@ def spaces_in_path(path):
|
|||
file/directory back to its original location.
|
||||
|
||||
Args:
|
||||
path (str | Path): The original path.
|
||||
path (str | Path): The original path that may contain spaces.
|
||||
|
||||
Yields:
|
||||
(Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path.
|
||||
|
||||
Example:
|
||||
```python
|
||||
with ultralytics.utils.files import spaces_in_path
|
||||
|
||||
with spaces_in_path('/path/with spaces') as new_path:
|
||||
# Your code here
|
||||
```
|
||||
Examples:
|
||||
Use the context manager to handle paths with spaces:
|
||||
>>> from ultralytics.utils.files import spaces_in_path
|
||||
>>> with spaces_in_path('/path/with spaces') as new_path:
|
||||
>>> # Your code here
|
||||
"""
|
||||
|
||||
# If path has spaces, replace them with underscores
|
||||
|
|
@ -84,21 +107,35 @@ def spaces_in_path(path):
|
|||
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
"""
|
||||
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
|
||||
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
|
||||
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
|
||||
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
||||
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
|
||||
number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a
|
||||
directory if it does not already exist.
|
||||
|
||||
Args:
|
||||
path (str, pathlib.Path): Path to increment.
|
||||
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False.
|
||||
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''.
|
||||
mkdir (bool, optional): Create a directory if it does not exist. Defaults to False.
|
||||
path (str | pathlib.Path): Path to increment.
|
||||
exist_ok (bool): If True, the path will not be incremented and returned as-is.
|
||||
sep (str): Separator to use between the path and the incrementation number.
|
||||
mkdir (bool): Create a directory if it does not exist.
|
||||
|
||||
Returns:
|
||||
(pathlib.Path): Incremented path.
|
||||
|
||||
Examples:
|
||||
Increment a directory path:
|
||||
>>> from pathlib import Path
|
||||
>>> path = Path("runs/exp")
|
||||
>>> new_path = increment_path(path)
|
||||
>>> print(new_path)
|
||||
runs/exp2
|
||||
|
||||
Increment a file path:
|
||||
>>> path = Path("runs/exp/results.txt")
|
||||
>>> new_path = increment_path(path)
|
||||
>>> print(new_path)
|
||||
runs/exp/results2.txt
|
||||
"""
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
|
|
@ -118,19 +155,19 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
|||
|
||||
|
||||
def file_age(path=__file__):
|
||||
"""Return days since last file update."""
|
||||
"""Return days since the last modification of the specified file."""
|
||||
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path=__file__):
|
||||
"""Return human-readable file modification date, i.e. '2021-3-26'."""
|
||||
"""Returns the file modification date in 'YYYY-M-D' format."""
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f"{t.year}-{t.month}-{t.day}"
|
||||
|
||||
|
||||
def file_size(path):
|
||||
"""Return file/dir size (MB)."""
|
||||
"""Returns the size of a file or directory in megabytes (MiB, i.e. units of 1024 ** 2 bytes)."""
|
||||
if isinstance(path, (str, Path)):
|
||||
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
||||
path = Path(path)
|
||||
|
|
@ -142,7 +179,7 @@ def file_size(path):
|
|||
|
||||
|
||||
def get_latest_run(search_dir="."):
|
||||
"""Return path to most recent 'last.pt' in /runs (i.e. to --resume from)."""
|
||||
"""Returns the path to the most recent (by ctime) 'last*.pt' file found recursively under the specified directory for resuming training, or an empty string if none exists."""
|
||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ""
|
||||
|
||||
|
|
@ -152,17 +189,15 @@ def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_name
|
|||
Updates and re-saves specified YOLO models in an 'updated_models' subdirectory.
|
||||
|
||||
Args:
|
||||
model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt").
|
||||
source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory.
|
||||
update_names (bool, optional): Update model names from a data YAML.
|
||||
model_names (Tuple[str, ...]): Model filenames to update.
|
||||
source_dir (Path): Directory containing models and target subdirectory.
|
||||
update_names (bool): Update model names from a data YAML.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from ultralytics.utils.files import update_models
|
||||
|
||||
model_names = (f"rtdetr-{size}.pt" for size in "lx")
|
||||
update_models(model_names)
|
||||
```
|
||||
Examples:
|
||||
Update the specified YOLO models and save them in the 'updated_models' subdirectory:
|
||||
>>> from ultralytics.utils.files import update_models
|
||||
>>> model_names = ("yolov8n.pt", "yolov8s.pt")
|
||||
>>> update_models(model_names, source_dir=Path("/models"), update_names=True)
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.nn.autobackend import default_class_names
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue