Integrate ByteTracker and BoT-SORT trackers (#788)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing 2023-02-16 00:23:03 +08:00 committed by GitHub
parent d99e04daa1
commit ed6c54da7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 1635 additions and 19 deletions

View file

@ -155,7 +155,8 @@ class YOLO:
overrides = self.overrides.copy()
overrides["conf"] = 0.25
overrides.update(kwargs)
overrides["mode"] = "predict"
overrides["mode"] = kwargs.get("mode", "predict")
assert overrides["mode"] in ['track', 'predict']
overrides["save"] = kwargs.get("save", False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
@ -165,6 +166,16 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker.track import register_tracker
register_tracker(self)
# bytetrack-based method needs low confidence predictions as input
conf = kwargs.get("conf") or 0.1
kwargs['conf'] = conf
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
def val(self, data=None, **kwargs):
"""