Integrate ByteTracker and BoT-SORT trackers (#788)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing 2023-02-16 00:23:03 +08:00 committed by GitHub
parent d99e04daa1
commit ed6c54da7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 1635 additions and 19 deletions

View file

@ -44,6 +44,14 @@ class Results:
setattr(r, item, getattr(self, item)[idx])
return r
def update(self, boxes=None, masks=None, probs=None):
if boxes is not None:
self.boxes = Boxes(boxes, self.orig_shape)
if masks is not None:
self.masks = Masks(masks, self.orig_shape)
if boxes is not None:
self.probs = probs
def cpu(self):
r = Results(orig_shape=self.orig_shape)
for item in self.comp:
@ -138,7 +146,10 @@ class Boxes:
def __init__(self, boxes, orig_shape) -> None:
if boxes.ndim == 1:
boxes = boxes[None, :]
assert boxes.shape[-1] == 6 # xyxy, conf, cls
n = boxes.shape[-1]
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
# TODO
self.is_track = n == 7
self.boxes = boxes
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
else np.asarray(orig_shape)
@ -155,6 +166,10 @@ class Boxes:
def cls(self):
return self.boxes[:, -1]
@property
def id(self):
return self.boxes[:, -3] if self.is_track else None
@property
@lru_cache(maxsize=2) # maxsize 1 should suffice
def xywh(self):
@ -303,7 +318,7 @@ class Masks:
def __getitem__(self, idx):
masks = self.masks[idx]
return Masks(masks, self.im_shape, self.orig_shape)
return Masks(masks, self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__