ultralytics 8.0.71 updates and fixes (#1907)
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Pavel Bugneac <50273042+pavelbugneac@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c38b17a0d8
commit
4e997013bc
19 changed files with 103 additions and 39 deletions
|
|
@ -1,5 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, GPL-3.0 license
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
|
||||
|
|
@ -10,7 +12,19 @@ from .trackers import BOTSORT, BYTETracker
|
|||
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
def on_predict_start(predictor, persist=False):
|
||||
"""
|
||||
Initialize trackers for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
predictor (object): The predictor object to initialize trackers for.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
"""
|
||||
if hasattr(predictor, 'trackers') and persist:
|
||||
return
|
||||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
|
||||
|
|
@ -38,6 +52,14 @@ def on_predict_postprocess_end(predictor):
|
|||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
|
||||
|
||||
|
||||
def register_tracker(model):
|
||||
model.add_callback('on_predict_start', on_predict_start)
|
||||
def register_tracker(model, persist):
|
||||
"""
|
||||
Register tracking callbacks to the model for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
"""
|
||||
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
|
||||
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue