diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 4a23ec42..628e1ce6 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -81,7 +81,7 @@ CLI_HELP_MSG = f""" 5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API yolo explorer data=data.yaml model=yolov8n.pt - 6. Streamlit real-time object detection on your webcam with Ultralytics YOLOv8 + 6. Streamlit real-time webcam inference GUI yolo streamlit-predict 7. Run special commands: diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py index c900cac4..bd2f1396 100644 --- a/ultralytics/trackers/basetrack.py +++ b/ultralytics/trackers/basetrack.py @@ -15,6 +15,11 @@ class TrackState: Tracked (int): State when the object is successfully tracked in subsequent frames. Lost (int): State when the object is no longer tracked. Removed (int): State when the object is removed from tracking. + + Examples: + >>> state = TrackState.New + >>> if state == TrackState.New: + ...     print("Object is newly detected.") """ New = 0 @@ -33,13 +38,13 @@ class BaseTrack: is_activated (bool): Flag indicating whether the track is currently active. state (TrackState): Current state of the track. history (OrderedDict): Ordered history of the track's states. - features (list): List of features extracted from the object for tracking. - curr_feature (any): The current feature of the object being tracked. + features (List): List of features extracted from the object for tracking. + curr_feature (Any): The current feature of the object being tracked. score (float): The confidence score of the tracking. start_frame (int): The frame number where tracking started. frame_id (int): The most recent frame ID processed by the track. time_since_update (int): Frames passed since the last update. - location (tuple): The location of the object in the context of multi-camera tracking. 
+ location (Tuple): The location of the object in the context of multi-camera tracking. Methods: end_frame: Returns the ID of the last frame where the object was tracked. @@ -50,12 +55,26 @@ class BaseTrack: mark_lost: Marks the track as lost. mark_removed: Marks the track as removed. reset_id: Resets the global track ID counter. + + Examples: + Initialize a new track and mark it as lost: + >>> track = BaseTrack() + >>> track.mark_lost() + >>> print(track.state) # Output: 2 (TrackState.Lost) """ _count = 0 def __init__(self): - """Initializes a new track with unique ID and foundational tracking attributes.""" + """ + Initializes a new track with a unique ID and foundational tracking attributes. + + Examples: + Initialize a new track + >>> track = BaseTrack() + >>> print(track.track_id) + 0 + """ self.track_id = 0 self.is_activated = False self.state = TrackState.New @@ -70,36 +89,36 @@ class BaseTrack: @property def end_frame(self): - """Return the last frame ID of the track.""" + """Returns the ID of the most recent frame where the object was tracked.""" return self.frame_id @staticmethod def next_id(): - """Increment and return the global track ID counter.""" + """Increment and return the next unique global track ID for object tracking.""" BaseTrack._count += 1 return BaseTrack._count def activate(self, *args): - """Abstract method to activate the track with provided arguments.""" + """Activates the track with provided arguments, initializing necessary attributes for tracking.""" raise NotImplementedError def predict(self): - """Abstract method to predict the next state of the track.""" + """Predicts the next state of the track based on the current state and tracking model.""" raise NotImplementedError def update(self, *args, **kwargs): - """Abstract method to update the track with new observations.""" + """Updates the track with new observations and data, modifying its state and attributes accordingly.""" raise NotImplementedError def mark_lost(self): - """Mark 
the track as lost.""" + """Marks the track as lost by updating its state to TrackState.Lost.""" self.state = TrackState.Lost def mark_removed(self): - """Mark the track as removed.""" + """Marks the track as removed by setting its state to TrackState.Removed.""" self.state = TrackState.Removed @staticmethod def reset_id(): - """Reset the global track ID counter.""" + """Reset the global track ID counter to its initial value.""" BaseTrack._count = 0 diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py index 862e217d..1f10dc7f 100644 --- a/ultralytics/trackers/bot_sort.py +++ b/ultralytics/trackers/bot_sort.py @@ -15,6 +15,9 @@ class BOTrack(STrack): """ An extended version of the STrack class for YOLOv8, adding object tracking features. + This class extends the STrack class to include additional functionalities for object tracking, such as feature + smoothing, Kalman filter prediction, and reactivation of tracks. + Attributes: shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. smooth_feat (np.ndarray): Smoothed feature vector. @@ -34,16 +37,35 @@ class BOTrack(STrack): convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format. tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`. 
- Usage: - bo_track = BOTrack(tlwh, score, cls, feat) - bo_track.predict() - bo_track.update(new_track, frame_id) + Examples: + Create a BOTrack instance and update its features + >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40, 0], score=0.9, cls=1, feat=np.random.rand(128)) + >>> bo_track.predict() + >>> new_track = BOTrack(tlwh=[110, 60, 80, 40, 0], score=0.85, cls=1, feat=np.random.rand(128)) + >>> bo_track.update(new_track, frame_id=2) """ shared_kalman = KalmanFilterXYWH() def __init__(self, tlwh, score, cls, feat=None, feat_history=50): - """Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features.""" + """ + Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features. + + Args: + tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height). + score (float): Confidence score of the detection. + cls (int): Class ID of the detected object. + feat (np.ndarray | None): Feature vector associated with the detection. + feat_history (int): Maximum length of the feature history deque. 
+ + Examples: + Initialize a BOTrack object with bounding box, score, class ID, and feature vector + >>> tlwh = np.array([100, 50, 80, 120, 0]) + >>> score = 0.9 + >>> cls = 1 + >>> feat = np.random.rand(128) + >>> bo_track = BOTrack(tlwh, score, cls, feat) + """ super().__init__(tlwh, score, cls) self.smooth_feat = None @@ -54,7 +76,7 @@ class BOTrack(STrack): self.alpha = 0.9 def update_features(self, feat): - """Update features vector and smooth it using exponential moving average.""" + """Update the feature vector and apply exponential moving average smoothing.""" feat /= np.linalg.norm(feat) self.curr_feat = feat if self.smooth_feat is None: @@ -65,7 +87,7 @@ class BOTrack(STrack): self.smooth_feat /= np.linalg.norm(self.smooth_feat) def predict(self): - """Predicts the mean and covariance using Kalman filter.""" + """Predicts the object's future state using the Kalman filter to update its mean and covariance.""" mean_state = self.mean.copy() if self.state != TrackState.Tracked: mean_state[6] = 0 @@ -80,14 +102,14 @@ class BOTrack(STrack): super().re_activate(new_track, frame_id, new_id) def update(self, new_track, frame_id): - """Update the YOLOv8 instance with new track and frame ID.""" + """Updates the BOTrack instance with new track information and the current frame ID.""" if new_track.curr_feat is not None: self.update_features(new_track.curr_feat) super().update(new_track, frame_id) @property def tlwh(self): - """Get current position in bounding box format `(top left x, top left y, width, height)`.""" + """Returns the current bounding box position in `(top left x, top left y, width, height)` format.""" if self.mean is None: return self._tlwh.copy() ret = self.mean[:4].copy() @@ -96,7 +118,7 @@ class BOTrack(STrack): @staticmethod def multi_predict(stracks): - """Predicts the mean and covariance of multiple object tracks using shared Kalman filter.""" + """Predicts the mean and covariance for multiple object tracks using a shared Kalman filter.""" if 
len(stracks) <= 0: return multi_mean = np.asarray([st.mean.copy() for st in stracks]) @@ -111,12 +133,12 @@ class BOTrack(STrack): stracks[i].covariance = cov def convert_coords(self, tlwh): - """Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format.""" + """Converts tlwh bounding box coordinates to xywh format.""" return self.tlwh_to_xywh(tlwh) @staticmethod def tlwh_to_xywh(tlwh): - """Convert bounding box to format `(center x, center y, width, height)`.""" + """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.""" ret = np.asarray(tlwh).copy() ret[:2] += ret[2:] / 2 return ret @@ -129,9 +151,9 @@ class BOTSORT(BYTETracker): Attributes: proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. - encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled. + encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled. gmc (GMC): An instance of the GMC algorithm for data association. - args (object): Parsed command-line arguments containing tracking parameters. + args (Any): Parsed command-line arguments containing tracking parameters. Methods: get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking. @@ -139,17 +161,29 @@ class BOTSORT(BYTETracker): get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID. multi_predict(tracks): Predict and track multiple objects with YOLOv8 model. 
- Usage: - bot_sort = BOTSORT(args, frame_rate) - bot_sort.init_track(dets, scores, cls, img) - bot_sort.multi_predict(tracks) + Examples: + Initialize BOTSORT and process detections + >>> bot_sort = BOTSORT(args, frame_rate=30) + >>> bot_sort.init_track(dets, scores, cls, img) + >>> bot_sort.multi_predict(tracks) Note: The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args. """ def __init__(self, args, frame_rate=30): - """Initialize YOLOv8 object with ReID module and GMC algorithm.""" + """ + Initialize YOLOv8 object with ReID module and GMC algorithm. + + Args: + args (object): Parsed command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video being processed. + + Examples: + Initialize BOTSORT with command-line arguments and a specified frame rate: + >>> args = parse_args() + >>> bot_sort = BOTSORT(args, frame_rate=30) + """ super().__init__(args, frame_rate) # ReID module self.proximity_thresh = args.proximity_thresh @@ -161,11 +195,11 @@ class BOTSORT(BYTETracker): self.gmc = GMC(method=args.gmc_method) def get_kalmanfilter(self): - """Returns an instance of KalmanFilterXYWH for object tracking.""" + """Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.""" return KalmanFilterXYWH() def init_track(self, dets, scores, cls, img=None): - """Initialize track with detections, scores, and classes.""" + """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.""" if len(dets) == 0: return [] if self.args.with_reid and self.encoder is not None: @@ -175,7 +209,7 @@ class BOTSORT(BYTETracker): return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections def get_dists(self, tracks, detections): - """Get distances between tracks and detections using IoU and (optionally) ReID embeddings.""" + """Calculates distances between tracks and detections 
using IoU and optionally ReID embeddings.""" dists = matching.iou_distance(tracks, detections) dists_mask = dists > self.proximity_thresh @@ -190,10 +224,10 @@ class BOTSORT(BYTETracker): return dists def multi_predict(self, tracks): - """Predict and track multiple objects with YOLOv8 model.""" + """Predicts the mean and covariance of multiple object tracks using a shared Kalman filter.""" BOTrack.multi_predict(tracks) def reset(self): - """Reset tracker.""" + """Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states.""" super().reset() self.gmc.reset_params() diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py index 7b4dc00f..6fd39466 100644 --- a/ultralytics/trackers/byte_tracker.py +++ b/ultralytics/trackers/byte_tracker.py @@ -25,7 +25,7 @@ class STrack(BaseTrack): is_activated (bool): Boolean flag indicating if the track has been activated. score (float): Confidence score of the track. tracklet_len (int): Length of the tracklet. - cls (any): Class label for the object. + cls (Any): Class label for the object. idx (int): Index or identifier for the object. frame_id (int): Current frame ID. start_frame (int): Frame where the object was first detected. @@ -39,12 +39,31 @@ class STrack(BaseTrack): update(new_track, frame_id): Update the state of a matched track. convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. + + Examples: + Initialize and activate a new track + >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls='person') + >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1) """ shared_kalman = KalmanFilterXYAH() def __init__(self, xywh, score, cls): - """Initialize new STrack instance.""" + """ + Initialize a new STrack instance. 
+ + Args: + xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where + (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id. + score (float): Confidence score of the detection. + cls (Any): Class label for the detected object. + + Examples: + >>> xywh = [100.0, 150.0, 50.0, 75.0, 1] + >>> score = 0.9 + >>> cls = 'person' + >>> track = STrack(xywh, score, cls) + """ super().__init__() # xywh+idx or xywha+idx assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" @@ -60,7 +79,7 @@ class STrack(BaseTrack): self.angle = xywh[4] if len(xywh) == 6 else None def predict(self): - """Predicts mean and covariance using Kalman filter.""" + """Predicts the next state (mean and covariance) of the object using the Kalman filter.""" mean_state = self.mean.copy() if self.state != TrackState.Tracked: mean_state[7] = 0 @@ -68,7 +87,7 @@ class STrack(BaseTrack): @staticmethod def multi_predict(stracks): - """Perform multi-object predictive tracking using Kalman filter for given stracks.""" + """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances.""" if len(stracks) <= 0: return multi_mean = np.asarray([st.mean.copy() for st in stracks]) @@ -83,7 +102,7 @@ class STrack(BaseTrack): @staticmethod def multi_gmc(stracks, H=np.eye(2, 3)): - """Update state tracks positions and covariances using a homography matrix.""" + """Update state tracks positions and covariances using a homography matrix for multiple tracks.""" if len(stracks) > 0: multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_covariance = np.asarray([st.covariance for st in stracks]) @@ -101,7 +120,7 @@ class STrack(BaseTrack): stracks[i].covariance = cov def activate(self, kalman_filter, frame_id): - """Start a new tracklet.""" + """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.""" self.kalman_filter = 
kalman_filter self.track_id = self.next_id() self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh)) @@ -114,7 +133,7 @@ class STrack(BaseTrack): self.start_frame = frame_id def re_activate(self, new_track, frame_id, new_id=False): - """Reactivates a previously lost track with a new detection.""" + """Reactivates a previously lost track using new detection data and updates its state and attributes.""" self.mean, self.covariance = self.kalman_filter.update( self.mean, self.covariance, self.convert_coords(new_track.tlwh) ) @@ -136,6 +155,12 @@ class STrack(BaseTrack): Args: new_track (STrack): The new track containing updated information. frame_id (int): The ID of the current frame. + + Examples: + Update the state of a track with new detection information + >>> track = STrack([100, 200, 50, 80, 1], score=0.9, cls=0) + >>> new_track = STrack([105, 205, 55, 85, 1], score=0.95, cls=0) + >>> track.update(new_track, 2) """ self.frame_id = frame_id self.tracklet_len += 1 @@ -158,7 +183,7 @@ class STrack(BaseTrack): @property def tlwh(self): - """Get current position in bounding box format (top left x, top left y, width, height).""" + """Returns the bounding box in top-left-width-height format from the current state estimate.""" if self.mean is None: return self._tlwh.copy() ret = self.mean[:4].copy() @@ -168,16 +193,14 @@ class STrack(BaseTrack): @property def xyxy(self): - """Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right).""" + """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.""" ret = self.tlwh.copy() ret[2:] += ret[:2] return ret @staticmethod def tlwh_to_xyah(tlwh): - """Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width / - height. 
- """ + """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.""" ret = np.asarray(tlwh).copy() ret[:2] += ret[2:] / 2 ret[2] /= ret[3] @@ -185,14 +208,14 @@ class STrack(BaseTrack): @property def xywh(self): - """Get current position in bounding box format (center x, center y, width, height).""" + """Returns the current position of the bounding box in (center x, center y, width, height) format.""" ret = np.asarray(self.tlwh).copy() ret[:2] += ret[2:] / 2 return ret @property def xywha(self): - """Get current position in bounding box format (center x, center y, width, height, angle).""" + """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing.""" if self.angle is None: LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.") return self.xywh @@ -200,12 +223,12 @@ class STrack(BaseTrack): @property def result(self): - """Get current tracking results.""" + """Returns the current tracking results in the appropriate bounding box format.""" coords = self.xyxy if self.angle is None else self.xywha return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] def __repr__(self): - """Return a string representation of the BYTETracker object with start and end frames and track ID.""" + """Returns a string representation of the STrack object including start frame, end frame, and track ID.""" return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" @@ -213,18 +236,18 @@ class BYTETracker: """ BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. - The class is responsible for initializing, updating, and managing the tracks for detected objects in a video - sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for - predicting the new object locations, and performs data association. 
+ Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence. + It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting + the new object locations, and performs data association. Attributes: - tracked_stracks (list[STrack]): List of successfully activated tracks. - lost_stracks (list[STrack]): List of lost tracks. - removed_stracks (list[STrack]): List of removed tracks. + tracked_stracks (List[STrack]): List of successfully activated tracks. + lost_stracks (List[STrack]): List of lost tracks. + removed_stracks (List[STrack]): List of removed tracks. frame_id (int): The current frame ID. - args (namespace): Command-line arguments. + args (Namespace): Command-line arguments. max_time_lost (int): The maximum frames for a track to be considered as 'lost'. - kalman_filter (object): Kalman Filter object. + kalman_filter (KalmanFilterXYAH): Kalman Filter object. Methods: update(results, img=None): Updates object tracker with new detections. @@ -236,10 +259,27 @@ class BYTETracker: joint_stracks(tlista, tlistb): Combines two lists of stracks. sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list. remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU. + + Examples: + Initialize BYTETracker and update with detection results + >>> tracker = BYTETracker(args, frame_rate=30) + >>> results = yolo_model.detect(image) + >>> tracked_objects = tracker.update(results) """ def __init__(self, args, frame_rate=30): - """Initialize a YOLOv8 object to track objects with given arguments and frame rate.""" + """ + Initialize a BYTETracker instance for object tracking. + + Args: + args (Namespace): Command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video sequence. 
+ + Examples: + Initialize BYTETracker with command-line arguments and a frame rate of 30 + >>> args = Namespace(track_buffer=30) + >>> tracker = BYTETracker(args, frame_rate=30) + """ self.tracked_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack] @@ -251,7 +291,7 @@ class BYTETracker: self.reset_id() def update(self, results, img=None): - """Updates object tracker with new detections and returns tracked object bounding boxes.""" + """Updates the tracker with new detections and returns the current list of tracked objects.""" self.frame_id += 1 activated_stracks = [] refind_stracks = [] @@ -365,31 +405,31 @@ class BYTETracker: return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) def get_kalmanfilter(self): - """Returns a Kalman filter object for tracking bounding boxes.""" + """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.""" return KalmanFilterXYAH() def init_track(self, dets, scores, cls, img=None): - """Initialize object tracking with detections and scores using STrack algorithm.""" + """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm.""" return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections def get_dists(self, tracks, detections): - """Calculates the distance between tracks and detections using IoU and fuses scores.""" + """Calculates the distance between tracks and detections using IoU and optionally fuses scores.""" dists = matching.iou_distance(tracks, detections) if self.args.fuse_score: dists = matching.fuse_score(dists, detections) return dists def multi_predict(self, tracks): - """Returns the predicted tracks using the YOLOv8 network.""" + """Predict the next states for multiple tracks using Kalman filter.""" STrack.multi_predict(tracks) @staticmethod def reset_id(): - """Resets the ID counter of 
STrack.""" + """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions.""" STrack.reset_id() def reset(self): - """Reset tracker.""" + """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.""" self.tracked_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack] @@ -399,7 +439,7 @@ class BYTETracker: @staticmethod def joint_stracks(tlista, tlistb): - """Combine two lists of stracks into a single one.""" + """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.""" exists = {} res = [] for t in tlista: @@ -414,20 +454,13 @@ class BYTETracker: @staticmethod def sub_stracks(tlista, tlistb): - """DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/ - stracks = {t.track_id: t for t in tlista} - for t in tlistb: - tid = t.track_id - if stracks.get(tid, 0): - del stracks[tid] - return list(stracks.values()) - """ + """Filters out the stracks present in the second list from the first list.""" track_ids_b = {t.track_id for t in tlistb} return [t for t in tlista if t.track_id not in track_ids_b] @staticmethod def remove_duplicate_stracks(stracksa, stracksb): - """Remove duplicate stracks with non-maximum IoU distance.""" + """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance.""" pdist = matching.iou_distance(stracksa, stracksb) pairs = np.where(pdist < 0.15) dupa, dupb = [], [] diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index 151abaaf..b0103cf9 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -21,10 +21,15 @@ def on_predict_start(predictor: object, persist: bool = False) -> None: Args: predictor (object): The predictor object to initialize trackers for. - persist (bool, optional): Whether to persist the trackers if they already exist. 
Defaults to False. + persist (bool): Whether to persist the trackers if they already exist. Raises: AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. + + Examples: + Initialize trackers for a predictor object: + >>> predictor = SomePredictorClass() + >>> on_predict_start(predictor, persist=True) """ if hasattr(predictor, "trackers") and persist: return @@ -51,7 +56,12 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None Args: predictor (object): The predictor object containing the predictions. - persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. + persist (bool): Whether to persist the trackers if they already exist. + + Examples: + Postprocess predictions and update with tracking + >>> predictor = YourPredictorClass() + >>> on_predict_postprocess_end(predictor, persist=True) """ path, im0s = predictor.batch[:2] @@ -84,6 +94,11 @@ def register_tracker(model: object, persist: bool) -> None: Args: model (object): The model object to register tracking callbacks for. persist (bool): Whether to persist the trackers if they already exist. + + Examples: + Register tracking callbacks to a YOLO model + >>> model = YOLOModel() + >>> register_tracker(model, persist=True) """ model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py index bc7d96d4..9b98de65 100644 --- a/ultralytics/trackers/utils/gmc.py +++ b/ultralytics/trackers/utils/gmc.py @@ -19,27 +19,39 @@ class GMC: method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. downscale (int): Factor by which to downscale the frames for processing. prevFrame (np.ndarray): Stores the previous frame for tracking. 
- prevKeyPoints (list): Stores the keypoints from the previous frame. + prevKeyPoints (List): Stores the keypoints from the previous frame. prevDescriptors (np.ndarray): Stores the descriptors from the previous frame. initializedFirstFrame (bool): Flag to indicate if the first frame has been processed. Methods: - __init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method - and downscale factor. - apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses - provided detections. - applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame. - applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame. - applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame. + __init__: Initializes a GMC object with the specified method and downscale factor. + apply: Applies the chosen method to a raw frame and optionally uses provided detections. + applyEcc: Applies the ECC algorithm to a raw frame. + applyFeatures: Applies feature-based methods like ORB or SIFT to a raw frame. + applySparseOptFlow: Applies the Sparse Optical Flow method to a raw frame. + reset_params: Resets the internal parameters of the GMC object. + + Examples: + Create a GMC object and apply it to a frame + >>> gmc = GMC(method='sparseOptFlow', downscale=2) + >>> frame = np.array([[1, 2, 3], [4, 5, 6]]) + >>> processed_frame = gmc.apply(frame) + >>> print(processed_frame) + array([[1, 2, 3], + [4, 5, 6]]) """ def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: """ - Initialize a video tracker with specified parameters. + Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor. Args: method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. 
downscale (int): Downscale factor for processing frames. + + Examples: + Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2 + >>> gmc = GMC(method='sparseOptFlow', downscale=2) """ super().__init__() @@ -79,20 +91,21 @@ class GMC: def apply(self, raw_frame: np.array, detections: list = None) -> np.array: """ - Apply object detection on a raw frame using specified method. + Apply object detection on a raw frame using the specified method. Args: - raw_frame (np.ndarray): The raw frame to be processed. - detections (list): List of detections to be used in the processing. + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. Returns: - (np.ndarray): Processed frame. + (np.ndarray): Processed frame with applied object detection. Examples: - >>> gmc = GMC() - >>> gmc.apply(np.array([[1, 2, 3], [4, 5, 6]])) - array([[1, 2, 3], - [4, 5, 6]]) + >>> gmc = GMC(method='sparseOptFlow') + >>> raw_frame = np.random.rand(480, 640, 3) + >>> processed_frame = gmc.apply(raw_frame) + >>> print(processed_frame.shape) + (480, 640, 3) """ if self.method in {"orb", "sift"}: return self.applyFeatures(raw_frame, detections) @@ -105,19 +118,20 @@ class GMC: def applyEcc(self, raw_frame: np.array) -> np.array: """ - Apply ECC algorithm to a raw frame. + Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation. Args: - raw_frame (np.ndarray): The raw frame to be processed. + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). Returns: - (np.ndarray): Processed frame. + (np.ndarray): The processed frame with the applied ECC transformation. 
Examples: - >>> gmc = GMC() - >>> gmc.applyEcc(np.array([[1, 2, 3], [4, 5, 6]])) - array([[1, 2, 3], - [4, 5, 6]]) + >>> gmc = GMC(method='ecc') + >>> processed_frame = gmc.applyEcc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(processed_frame) + [[1. 0. 0.] + [0. 1. 0.]] """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) @@ -127,8 +141,6 @@ class GMC: if self.downscale > 1.0: frame = cv2.GaussianBlur(frame, (3, 3), 1.5) frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) - width = width // self.downscale - height = height // self.downscale # Handle first frame if not self.initializedFirstFrame: @@ -154,17 +166,18 @@ class GMC: Apply feature-based methods like ORB or SIFT to a raw frame. Args: - raw_frame (np.ndarray): The raw frame to be processed. - detections (list): List of detections to be used in the processing. + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. Returns: (np.ndarray): Processed frame. Examples: - >>> gmc = GMC() - >>> gmc.applyFeatures(np.array([[1, 2, 3], [4, 5, 6]])) - array([[1, 2, 3], - [4, 5, 6]]) + >>> gmc = GMC(method='orb') + >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_frame = gmc.applyFeatures(raw_frame) + >>> print(processed_frame.shape) + (2, 3) """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) @@ -296,16 +309,17 @@ class GMC: Apply Sparse Optical Flow method to a raw frame. Args: - raw_frame (np.ndarray): The raw frame to be processed. + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). Returns: - (np.ndarray): Processed frame. + (np.ndarray): Processed frame with shape (2, 3). 
Examples: >>> gmc = GMC() - >>> gmc.applySparseOptFlow(np.array([[1, 2, 3], [4, 5, 6]])) - array([[1, 2, 3], - [4, 5, 6]]) + >>> result = gmc.applySparseOptFlow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(result) + [[1. 0. 0.] + [0. 1. 0.]] """ height, width, _ = raw_frame.shape frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) @@ -356,7 +370,7 @@ class GMC: return H def reset_params(self) -> None: - """Reset parameters.""" + """Reset the internal parameters including previous frame, keypoints, and descriptors.""" self.prevFrame = None self.prevKeyPoints = None self.prevDescriptors = None diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py index 36e280f6..24b67d47 100644 --- a/ultralytics/trackers/utils/kalman_filter.py +++ b/ultralytics/trackers/utils/kalman_filter.py @@ -6,17 +6,49 @@ import scipy.linalg class KalmanFilterXYAH: """ - For bytetrack. A simple Kalman filter for tracking bounding boxes in image space. + A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter. - The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect - ratio a, height h, and their respective velocities. + Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space + (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their + respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is + taken as a direct observation of the state space (linear observation model). - Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct - observation of the state space (linear observation model). + Attributes: + _motion_mat (np.ndarray): The motion matrix for the Kalman filter. 
+ _update_mat (np.ndarray): The update matrix for the Kalman filter. + _std_weight_position (float): Standard deviation weight for position. + _std_weight_velocity (float): Standard deviation weight for velocity. + + Methods: + initiate: Creates a track from an unassociated measurement. + predict: Runs the Kalman filter prediction step. + project: Projects the state distribution to measurement space. + multi_predict: Runs the Kalman filter prediction step (vectorized version). + update: Runs the Kalman filter correction step. + gating_distance: Computes the gating distance between state distribution and measurements. + + Examples: + Initialize the Kalman filter and create a track from a measurement + >>> kf = KalmanFilterXYAH() + >>> measurement = np.array([100, 200, 1.5, 50]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + >>> print(covariance) """ def __init__(self): - """Initialize Kalman filter model matrices with motion and observation uncertainty weights.""" + """ + Initialize Kalman filter model matrices with motion and observation uncertainty weights. + + The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y) + represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective + velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear + observation model for bounding box location. + + Examples: + Initialize a Kalman filter for tracking: + >>> kf = KalmanFilterXYAH() + """ ndim, dt = 4, 1.0 # Create Kalman filter model matrices @@ -32,15 +64,20 @@ class KalmanFilterXYAH: def initiate(self, measurement: np.ndarray) -> tuple: """ - Create track from unassociated measurement. + Create a track from an unassociated measurement. Args: measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, and height h. 
Returns: - (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) + (tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> measurement = np.array([100, 50, 1.5, 200]) + >>> mean, covariance = kf.initiate(measurement) """ mean_pos = measurement mean_vel = np.zeros_like(mean_pos) @@ -64,12 +101,18 @@ class KalmanFilterXYAH: Run Kalman filter prediction step. Args: - mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. - covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. + mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. Returns: (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) """ std_pos = [ self._std_weight_position * mean[3], @@ -100,6 +143,12 @@ class KalmanFilterXYAH: Returns: (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> projected_mean, projected_covariance = kf.project(mean, covariance) """ std = [ self._std_weight_position * mean[3], @@ -115,15 +164,21 @@ class KalmanFilterXYAH: def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: """ - Run Kalman filter prediction step (Vectorized version). 
+ Run Kalman filter prediction step for multiple object states (Vectorized version). Args: mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. Returns: - (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved - velocities are initialized to 0 mean. + (tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states. + The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities + are initialized to 0 mean. + + Examples: + >>> mean = np.random.rand(10, 8) # 10 object states + >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states + >>> predicted_mean, predicted_covariance = KalmanFilterXYAH().multi_predict(mean, covariance) """ std_pos = [ self._std_weight_position * mean[:, 3], @@ -160,6 +215,13 @@ class KalmanFilterXYAH: Returns: (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurement = np.array([1, 1, 1, 1]) + >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) """ projected_mean, projected_cov = self.project(mean, covariance) @@ -182,23 +244,31 @@ class KalmanFilterXYAH: metric: str = "maha", ) -> np.ndarray: """ - Compute gating distance between state distribution and measurements. A suitable distance threshold can be - obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom, - otherwise 2. + Compute gating distance between state distribution and measurements. + + A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square + distribution has 4 degrees of freedom, otherwise 2.
Args: mean (ndarray): Mean vector over the state distribution (8 dimensional). covariance (ndarray): Covariance of the state distribution (8x8 dimensional). - measurements (ndarray): An Nx4 matrix of N measurements, each in format (x, y, a, h) where (x, y) - is the bounding box center position, a the aspect ratio, and h the height. - only_position (bool, optional): If True, distance computation is done with respect to the bounding box - center position only. Defaults to False. - metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the - squared Euclidean distance and 'maha' for the squared Mahalanobis distance. Defaults to 'maha'. + measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the + bounding box center position, a the aspect ratio, and h the height. + only_position (bool): If True, distance computation is done with respect to box center position only. + metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared + Euclidean distance and 'maha' for the squared Mahalanobis distance. Returns: (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between (mean, covariance) and `measurements[i]`. + + Examples: + Compute gating distance using Mahalanobis metric: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]]) + >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric='maha') """ mean, covariance = self.project(mean, covariance) if only_position: @@ -218,13 +288,33 @@ class KalmanFilterXYAH: class KalmanFilterXYWH(KalmanFilterXYAH): """ - For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space. + A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter. 
- The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width - w, height h, and their respective velocities. - - Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct + Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where + (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. + The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct observation of the state space (linear observation model). + + Attributes: + _motion_mat (np.ndarray): The motion matrix for the Kalman filter. + _update_mat (np.ndarray): The update matrix for the Kalman filter. + _std_weight_position (float): Standard deviation weight for position. + _std_weight_velocity (float): Standard deviation weight for velocity. + + Methods: + initiate: Creates a track from an unassociated measurement. + predict: Runs the Kalman filter prediction step. + project: Projects the state distribution to measurement space. + multi_predict: Runs the Kalman filter prediction step in a vectorized manner. + update: Runs the Kalman filter correction step. + + Examples: + Create a Kalman filter and initialize a track + >>> kf = KalmanFilterXYWH() + >>> measurement = np.array([100, 50, 20, 40]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + >>> print(covariance) """ def initiate(self, measurement: np.ndarray) -> tuple: @@ -237,6 +327,22 @@ class KalmanFilterXYWH(KalmanFilterXYAH): Returns: (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean. 
+ + Examples: + >>> kf = KalmanFilterXYWH() + >>> measurement = np.array([100, 50, 20, 40]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + [100. 50. 20. 40. 0. 0. 0. 0.] + >>> print(covariance) + [[ 4. 0. 0. 0. 0. 0. 0. 0.] + [ 0. 4. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 4. 0. 0. 0. 0. 0.] + [ 0. 0. 0. 4. 0. 0. 0. 0.] + [ 0. 0. 0. 0. 0.25 0. 0. 0.] + [ 0. 0. 0. 0. 0. 0.25 0. 0.] + [ 0. 0. 0. 0. 0. 0. 0.25 0.] + [ 0. 0. 0. 0. 0. 0. 0. 0.25]] """ mean_pos = measurement mean_vel = np.zeros_like(mean_pos) @@ -260,12 +366,18 @@ class KalmanFilterXYWH(KalmanFilterXYAH): Run Kalman filter prediction step. Args: - mean (ndarray): The 8 dimensional mean vector of the object state at the previous time step. - covariance (ndarray): The 8x8 dimensional covariance matrix of the object state at the previous time step. + mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. Returns: (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) """ std_pos = [ self._std_weight_position * mean[2], @@ -296,6 +408,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH): Returns: (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. 
+ + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> projected_mean, projected_cov = kf.project(mean, covariance) """ std = [ self._std_weight_position * mean[2], @@ -320,6 +438,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH): Returns: (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors + >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices + >>> kf = KalmanFilterXYWH() + >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance) """ std_pos = [ self._std_weight_position * mean[:, 2], @@ -356,5 +480,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH): Returns: (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. + + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurement = np.array([0.5, 0.5, 1.2, 1.2]) + >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) """ return super().update(mean, covariance, measurement) diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py index 222c3a5c..e7a60a45 100644 --- a/ultralytics/trackers/utils/matching.py +++ b/ultralytics/trackers/utils/matching.py @@ -19,18 +19,23 @@ except (ImportError, AssertionError, AttributeError): def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple: """ - Perform linear assignment using scipy or lap.lapjv. + Perform linear assignment using either the scipy or lap.lapjv method. Args: - cost_matrix (np.ndarray): The matrix containing cost values for assignments. + cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). 
thresh (float): Threshold for considering an assignment valid. - use_lap (bool, optional): Whether to use lap.lapjv. Defaults to True. + use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used. Returns: - Tuple with: - - matched indices - - unmatched indices from 'a' - - unmatched indices from 'b' + (tuple): A tuple containing: + - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches. + - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,). + - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,). + + Examples: + >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> thresh = 5.0 + >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True) """ if cost_matrix.size == 0: @@ -68,6 +73,12 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray: Returns: (np.ndarray): Cost matrix computed based on IoU. + + Examples: + Compute IoU distance between two sets of tracks + >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])] + >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])] + >>> cost_matrix = iou_distance(atracks, btracks) """ if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): @@ -98,12 +109,19 @@ def embedding_distance(tracks: list, detections: list, metric: str = "cosine") - Compute distance between tracks and detections based on embeddings. Args: - tracks (list[STrack]): List of tracks. - detections (list[BaseTrack]): List of detections. - metric (str, optional): Metric for distance computation. Defaults to 'cosine'. + tracks (list[STrack]): List of tracks, where each track contains embedding features. + detections (list[BaseTrack]): List of detections, where each detection contains embedding features. + metric (str): Metric for distance computation. 
Supported metrics include 'cosine', 'euclidean', etc. Returns: - (np.ndarray): Cost matrix computed based on embeddings. + (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks + and M is the number of detections. + + Examples: + Compute the embedding distance between tracks and detections using cosine metric + >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features + >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features + >>> cost_matrix = embedding_distance(tracks, detections, metric='cosine') """ cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) @@ -122,11 +140,17 @@ def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray: Fuses cost matrix with detection scores to produce a single similarity matrix. Args: - cost_matrix (np.ndarray): The matrix containing cost values for assignments. - detections (list[BaseTrack]): List of detections with scores. + cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). + detections (list[BaseTrack]): List of detections, each containing a score attribute. Returns: - (np.ndarray): Fused similarity matrix. + (np.ndarray): Fused similarity matrix with shape (N, M). 
+ + Examples: + Fuse a cost matrix with detection scores + >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections + >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] + >>> fused_matrix = fuse_score(cost_matrix, detections) """ if cost_matrix.size == 0: diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index 9d0f562e..f614eea2 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -47,7 +47,7 @@ PYTHON_VERSION = platform.python_version() TORCH_VERSION = torch.__version__ TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision HELP_MSG = """ - Usage examples for running Ultralytics YOLO: + Examples for running Ultralytics: 1. Install the ultralytics package: diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py index 719cacae..e710c903 100644 --- a/ultralytics/utils/files.py +++ b/ultralytics/utils/files.py @@ -11,19 +11,44 @@ from pathlib import Path class WorkingDirectory(contextlib.ContextDecorator): - """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" + """ + A context manager and decorator for temporarily changing the working directory. + + This class allows for the temporary change of the working directory using a context manager or decorator. + It ensures that the original working directory is restored after the context or decorated function completes. + + Attributes: + dir (Path): The new directory to switch to. + cwd (Path): The original current working directory before the switch. + + Methods: + __enter__: Changes the current directory to the specified directory. + __exit__: Restores the original working directory on context exit. 
+ + Examples: + Using as a context manager: + >>> with WorkingDirectory('/path/to/new/dir'): + >>> # Perform operations in the new directory + >>> pass + + Using as a decorator: + >>> @WorkingDirectory('/path/to/new/dir') + >>> def some_function(): + >>> # Perform operations in the new directory + >>> pass + """ def __init__(self, new_dir): - """Sets the working directory to 'new_dir' upon instantiation.""" + """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators.""" self.dir = new_dir # new dir self.cwd = Path.cwd().resolve() # current dir def __enter__(self): - """Changes the current directory to the specified directory.""" + """Changes the current working directory to the specified directory upon entering the context.""" os.chdir(self.dir) def __exit__(self, exc_type, exc_val, exc_tb): # noqa - """Restore the current working directory on context exit.""" + """Restores the original working directory when exiting the context.""" os.chdir(self.cwd) @@ -35,18 +60,16 @@ def spaces_in_path(path): file/directory back to its original location. Args: - path (str | Path): The original path. + path (str | Path): The original path that may contain spaces. Yields: (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path. - Example: - ```python - with ultralytics.utils.files import spaces_in_path - - with spaces_in_path('/path/with spaces') as new_path: - # Your code here - ``` + Examples: + Use the context manager to handle paths with spaces: + >>> from ultralytics.utils.files import spaces_in_path + >>> with spaces_in_path('/path/with spaces') as new_path: + >>> # Your code here """ # If path has spaces, replace them with underscores @@ -84,21 +107,35 @@ def spaces_in_path(path): def increment_path(path, exist_ok=False, sep="", mkdir=False): """ - Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 
+ Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. - If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to + If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the - number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a + number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a directory if it does not already exist. Args: - path (str, pathlib.Path): Path to increment. - exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. Defaults to False. - sep (str, optional): Separator to use between the path and the incrementation number. Defaults to ''. - mkdir (bool, optional): Create a directory if it does not exist. Defaults to False. + path (str | pathlib.Path): Path to increment. + exist_ok (bool): If True, the path will not be incremented and returned as-is. + sep (str): Separator to use between the path and the incrementation number. + mkdir (bool): Create a directory if it does not exist. Returns: (pathlib.Path): Incremented path. 
+ + Examples: + Increment a directory path: + >>> from pathlib import Path + >>> path = Path("runs/exp") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp2 + + Increment a file path: + >>> path = Path("runs/exp/results.txt") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp/results2.txt """ path = Path(path) # os-agnostic if path.exists() and not exist_ok: @@ -118,19 +155,19 @@ def increment_path(path, exist_ok=False, sep="", mkdir=False): def file_age(path=__file__): - """Return days since last file update.""" + """Return days since the last modification of the specified file.""" dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta return dt.days # + dt.seconds / 86400 # fractional days def file_date(path=__file__): - """Return human-readable file modification date, i.e. '2021-3-26'.""" + """Returns the file modification date in 'YYYY-M-D' format.""" t = datetime.fromtimestamp(Path(path).stat().st_mtime) return f"{t.year}-{t.month}-{t.day}" def file_size(path): - """Return file/dir size (MB).""" + """Returns the size of a file or directory in megabytes (MB).""" if isinstance(path, (str, Path)): mb = 1 << 20 # bytes to MiB (1024 ** 2) path = Path(path) @@ -142,7 +179,7 @@ def file_size(path): def get_latest_run(search_dir="."): - """Return path to most recent 'last.pt' in /runs (i.e. to --resume from).""" + """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training.""" last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) return max(last_list, key=os.path.getctime) if last_list else "" @@ -152,17 +189,15 @@ def update_models(model_names=("yolov8n.pt",), source_dir=Path("."), update_name Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. Args: - model_names (tuple, optional): Model filenames to update, defaults to ("yolov8n.pt"). 
- source_dir (Path, optional): Directory containing models and target subdirectory, defaults to current directory. - update_names (bool, optional): Update model names from a data YAML. + model_names (Tuple[str, ...]): Model filenames to update. + source_dir (Path): Directory containing models and target subdirectory. + update_names (bool): Update model names from a data YAML. - Example: - ```python - from ultralytics.utils.files import update_models - - model_names = (f"rtdetr-{size}.pt" for size in "lx") - update_models(model_names) - ``` + Examples: + Update specified YOLO models and save them in 'updated_models' subdirectory: + >>> from ultralytics.utils.files import update_models + >>> model_names = ("yolov8n.pt", "yolov8s.pt") + >>> update_models(model_names, source_dir=Path("/models"), update_names=True) """ from ultralytics import YOLO from ultralytics.nn.autobackend import default_class_names