From 0f81777af509f694b2e6c1ffb95486471700d318 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Tue, 18 Feb 2025 19:59:30 +0800 Subject: [PATCH] Fix error with `torch` tensor input in `model.track()` (#19278) Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher --- ultralytics/trackers/track.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py index 6e422f0d..306a884b 100644 --- a/ultralytics/trackers/track.py +++ b/ultralytics/trackers/track.py @@ -66,25 +66,23 @@ def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None >>> predictor = YourPredictorClass() >>> on_predict_postprocess_end(predictor, persist=True) """ - path, im0s = predictor.batch[:2] - is_obb = predictor.args.task == "obb" is_stream = predictor.dataset.mode == "stream" - for i in range(len(im0s)): + for i, result in enumerate(predictor.results): tracker = predictor.trackers[i if is_stream else 0] - vid_path = predictor.save_dir / Path(path[i]).name + vid_path = predictor.save_dir / Path(result.path).name if not persist and predictor.vid_path[i if is_stream else 0] != vid_path: tracker.reset() predictor.vid_path[i if is_stream else 0] = vid_path - det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy() + det = (result.obb if is_obb else result.boxes).cpu().numpy() if len(det) == 0: continue - tracks = tracker.update(det, im0s[i]) + tracks = tracker.update(det, result.orig_img) if len(tracks) == 0: continue idx = tracks[:, -1].astype(int) - predictor.results[i] = predictor.results[i][idx] + predictor.results[i] = result[idx] update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])} predictor.results[i].update(**update_args)