Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -12,6 +12,33 @@ from .utils.kalman_filter import KalmanFilterXYWH
|
|||
|
||||
|
||||
class BOTrack(STrack):
|
||||
"""
|
||||
An extended version of the STrack class for YOLOv8, adding object tracking features.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
|
||||
smooth_feat (np.ndarray): Smoothed feature vector.
|
||||
curr_feat (np.ndarray): Current feature vector.
|
||||
features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
|
||||
alpha (float): Smoothing factor for the exponential moving average of features.
|
||||
mean (np.ndarray): The mean state of the Kalman filter.
|
||||
covariance (np.ndarray): The covariance matrix of the Kalman filter.
|
||||
|
||||
Methods:
|
||||
update_features(feat): Update features vector and smooth it using exponential moving average.
|
||||
predict(): Predicts the mean and covariance using Kalman filter.
|
||||
re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID.
|
||||
update(new_track, frame_id): Update the track with a newly matched track and the current frame ID.
|
||||
tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
|
||||
multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter.
|
||||
convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format.
|
||||
tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`.
|
||||
|
||||
Usage:
|
||||
bo_track = BOTrack(tlwh, score, cls, feat)
|
||||
bo_track.predict()
|
||||
bo_track.update(new_track, frame_id)
|
||||
"""
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
|
|
@ -59,9 +86,7 @@ class BOTrack(STrack):
|
|||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
"""Get current position in bounding box format `(top left x, top left y, width, height)`."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -90,15 +115,37 @@ class BOTrack(STrack):
|
|||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width,
|
||||
height)`.
|
||||
"""
|
||||
"""Convert bounding box to format `(center x, center y, width, height)`."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
|
||||
class BOTSORT(BYTETracker):
|
||||
"""
|
||||
An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm.
|
||||
|
||||
Attributes:
|
||||
proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
|
||||
appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
|
||||
encoder (object): Object to handle ReID embeddings, set to None if ReID is not enabled.
|
||||
gmc (GMC): An instance of the GMC algorithm for data association.
|
||||
args (object): Parsed command-line arguments containing tracking parameters.
|
||||
|
||||
Methods:
|
||||
get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking.
|
||||
init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes.
|
||||
get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID.
|
||||
multi_predict(tracks): Predict and track multiple objects with YOLOv8 model.
|
||||
|
||||
Usage:
|
||||
bot_sort = BOTSORT(args, frame_rate)
|
||||
bot_sort.init_track(dets, scores, cls, img)
|
||||
bot_sort.multi_predict(tracks)
|
||||
|
||||
Note:
|
||||
The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
|
|
|
|||
|
|
@ -8,10 +8,43 @@ from .utils.kalman_filter import KalmanFilterXYAH
|
|||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
"""
|
||||
Single object tracking representation that uses Kalman filtering for state estimation.
|
||||
|
||||
This class is responsible for storing all the information regarding individual tracklets and performs state updates
|
||||
and predictions based on Kalman filter.
|
||||
|
||||
Attributes:
|
||||
shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction.
|
||||
_tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
|
||||
kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
|
||||
mean (np.ndarray): Mean state estimate vector.
|
||||
covariance (np.ndarray): Covariance of state estimate.
|
||||
is_activated (bool): Boolean flag indicating if the track has been activated.
|
||||
score (float): Confidence score of the track.
|
||||
tracklet_len (int): Length of the tracklet.
|
||||
cls (any): Class label for the object.
|
||||
idx (int): Index or identifier for the object.
|
||||
frame_id (int): Current frame ID.
|
||||
start_frame (int): Frame where the object was first detected.
|
||||
|
||||
Methods:
|
||||
predict(): Predict the next state of the object using Kalman filter.
|
||||
multi_predict(stracks): Predict the next states for multiple tracks.
|
||||
multi_gmc(stracks, H): Update multiple track states using a homography matrix.
|
||||
activate(kalman_filter, frame_id): Activate a new tracklet.
|
||||
re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet.
|
||||
update(new_track, frame_id): Update the state of a matched track.
|
||||
convert_coords(tlwh): Convert bounding box to xyah format (center x, center y, aspect ratio, height).
|
||||
tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format.
|
||||
tlbr_to_tlwh(tlbr): Convert tlbr bounding box to tlwh format.
|
||||
tlwh_to_tlbr(tlwh): Convert tlwh bounding box to tlbr format.
|
||||
"""
|
||||
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, tlwh, score, cls):
|
||||
"""wait activate."""
|
||||
"""Initialize new STrack instance."""
|
||||
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
|
|
@ -92,10 +125,11 @@ class STrack(BaseTrack):
|
|||
|
||||
def update(self, new_track, frame_id):
|
||||
"""
|
||||
Update a matched track
|
||||
:type new_track: STrack
|
||||
:type frame_id: int
|
||||
:return:
|
||||
Update the state of a matched track.
|
||||
|
||||
Args:
|
||||
new_track (STrack): The new track containing updated information.
|
||||
frame_id (int): The ID of the current frame.
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
|
@ -116,9 +150,7 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
"""Get current position in bounding box format (top left x, top left y, width, height)."""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
|
|
@ -128,17 +160,15 @@ class STrack(BaseTrack):
|
|||
|
||||
@property
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
"""Convert bounding box to format (min x, min y, max x, max y), i.e., (top left, bottom right)."""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""Convert bounding box to format (center x, center y, aspect ratio, height), where the aspect ratio is width /
|
||||
height.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
|
|
@ -165,6 +195,33 @@ class STrack(BaseTrack):
|
|||
|
||||
|
||||
class BYTETracker:
|
||||
"""
|
||||
BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
|
||||
|
||||
The class is responsible for initializing, updating, and managing the tracks for detected objects in a video
|
||||
sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
|
||||
predicting the new object locations, and performs data association.
|
||||
|
||||
Attributes:
|
||||
tracked_stracks (list[STrack]): List of successfully activated tracks.
|
||||
lost_stracks (list[STrack]): List of lost tracks.
|
||||
removed_stracks (list[STrack]): List of removed tracks.
|
||||
frame_id (int): The current frame ID.
|
||||
args (namespace): Command-line arguments.
|
||||
max_time_lost (int): The maximum frames for a track to be considered as 'lost'.
|
||||
kalman_filter (object): Kalman Filter object.
|
||||
|
||||
Methods:
|
||||
update(results, img=None): Updates object tracker with new detections.
|
||||
get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes.
|
||||
init_track(dets, scores, cls, img=None): Initialize object tracking with detections.
|
||||
get_dists(tracks, detections): Calculates the distance between tracks and detections.
|
||||
multi_predict(tracks): Predicts the location of tracks.
|
||||
reset_id(): Resets the ID counter of STrack.
|
||||
joint_stracks(tlista, tlistb): Combines two lists of stracks.
|
||||
sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list.
|
||||
remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IOU.
|
||||
"""
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
|
|
@ -234,8 +291,7 @@ class BYTETracker:
|
|||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Step 3: Second association, with low score detection boxes
|
||||
# association the untrack to the low score detections
|
||||
# Step 3: Second association, with low score detection boxes; associate the untracked tracks to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
# TODO
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ def register_tracker(model, persist):
|
|||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
"""
|
||||
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
|
||||
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
|
||||
|
|
|
|||
|
|
@ -9,6 +9,29 @@ from ultralytics.utils import LOGGER
|
|||
|
||||
|
||||
class GMC:
|
||||
"""
|
||||
Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
|
||||
|
||||
This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
|
||||
SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
|
||||
|
||||
Attributes:
|
||||
method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
|
||||
downscale (int): Factor by which to downscale the frames for processing.
|
||||
prevFrame (np.ndarray): Stores the previous frame for tracking.
|
||||
prevKeyPoints (list): Stores the keypoints from the previous frame.
|
||||
prevDescriptors (np.ndarray): Stores the descriptors from the previous frame.
|
||||
initializedFirstFrame (bool): Flag to indicate if the first frame has been processed.
|
||||
|
||||
Methods:
|
||||
__init__(self, method='sparseOptFlow', downscale=2): Initializes a GMC object with the specified method
|
||||
and downscale factor.
|
||||
apply(self, raw_frame, detections=None): Applies the chosen method to a raw frame and optionally uses
|
||||
provided detections.
|
||||
applyEcc(self, raw_frame, detections=None): Applies the ECC algorithm to a raw frame.
|
||||
applyFeatures(self, raw_frame, detections=None): Applies feature-based methods like ORB or SIFT to a raw frame.
|
||||
applySparseOptFlow(self, raw_frame, detections=None): Applies the Sparse Optical Flow method to a raw frame.
|
||||
"""
|
||||
|
||||
def __init__(self, method='sparseOptFlow', downscale=2):
|
||||
"""Initialize a video tracker with specified parameters."""
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ class KalmanFilterXYAH:
|
|||
"""
|
||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
|
||||
aspect ratio a, height h, and their respective velocities.
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect
|
||||
ratio a, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
|
|
@ -182,8 +182,8 @@ class KalmanFilterXYAH:
|
|||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""
|
||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of freedom,
|
||||
otherwise 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -223,8 +223,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
"""
|
||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
|
||||
width w, height h, and their respective velocities.
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y), width
|
||||
w, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue