Add utils.ops and nn.modules to tests (#4484)
This commit is contained in:
parent
1cec0185a1
commit
6da8f7f51e
14 changed files with 246 additions and 330 deletions
|
|
@ -6,20 +6,13 @@ import scipy.linalg
|
|||
|
||||
class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
For bytetrack. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, a, h, vx, vy, va, vh
|
||||
|
||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
The 8-dimensional state space (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y),
|
||||
aspect ratio a, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, a, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -32,14 +25,14 @@ class KalmanFilterXYAH:
|
|||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
# Motion and observation uncertainty are chosen relative to the current state estimate. These weights control
|
||||
# the amount of uncertainty in the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -53,7 +46,6 @@ class KalmanFilterXYAH:
|
|||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
|
|
@ -67,23 +59,21 @@ class KalmanFilterXYAH:
|
|||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
||||
|
|
@ -100,7 +90,8 @@ class KalmanFilterXYAH:
|
|||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -112,9 +103,7 @@ class KalmanFilterXYAH:
|
|||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
||||
|
|
@ -126,20 +115,21 @@ class KalmanFilterXYAH:
|
|||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
||||
previous time step.
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
||||
|
|
@ -159,7 +149,8 @@ class KalmanFilterXYAH:
|
|||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -168,15 +159,13 @@ class KalmanFilterXYAH:
|
|||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center position, a the aspect
|
||||
ratio, and h the height of the bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
|
|
@ -191,10 +180,11 @@ class KalmanFilterXYAH:
|
|||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
"""
|
||||
Compute gating distance between state distribution and measurements. A suitable distance threshold can be
|
||||
obtained from `chi2inv95`. If `only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
|
|
@ -202,18 +192,16 @@ class KalmanFilterXYAH:
|
|||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
An Nx4 dimensional matrix of N measurements, each in format (x, y, a, h) where (x, y) is the bounding box
|
||||
center position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
If True, distance computation is done with respect to the bounding box center position only.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
Returns an array of length N, where the i-th element contains the squared Mahalanobis distance between
|
||||
(mean, covariance) and `measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
|
|
@ -233,38 +221,29 @@ class KalmanFilterXYAH:
|
|||
|
||||
class KalmanFilterXYWH(KalmanFilterXYAH):
|
||||
"""
|
||||
For BoT-SORT
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
For BoT-SORT. A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, w, h, vx, vy, vw, vh
|
||||
|
||||
contains the bounding box center position (x, y), width w, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, w, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
The 8-dimensional state space (x, y, w, h, vx, vy, vw, vh) contains the bounding box center position (x, y),
|
||||
width w, height h, and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location (x, y, w, h) is taken as direct
|
||||
observation of the state space (linear observation model).
|
||||
"""
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
"""
|
||||
Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, w, h) with center position (x, y),
|
||||
width w, and height h.
|
||||
Bounding box coordinates (x, y, w, h) with center position (x, y), width w, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) of the new track.
|
||||
Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
|
|
@ -279,23 +258,21 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
"""
|
||||
Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
The 8 dimensional mean vector of the object state at the previous time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
The 8x8 dimensional covariance matrix of the object state at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
|
|
@ -311,7 +288,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
"""
|
||||
Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -323,9 +301,7 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
Returns the projected mean and covariance matrix of the given state estimate.
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
|
|
@ -337,20 +313,21 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
"""
|
||||
Run Kalman filter prediction step (Vectorized version).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
The Nx8 dimensional mean matrix of the object states at the previous time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
||||
previous time step.
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
Returns the mean vector and covariance matrix of the predicted state. Unobserved velocities are
|
||||
initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
||||
|
|
@ -370,7 +347,8 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
"""
|
||||
Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
|
@ -379,14 +357,12 @@ class KalmanFilterXYWH(KalmanFilterXYAH):
|
|||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
|
||||
is the center position, w the width, and h the height of the
|
||||
bounding box.
|
||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center position, w the width,
|
||||
and h the height of the bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
return super().update(mean, covariance, measurement)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue