From e0cbbf409d808b7dbfe2b03e015524acb6605fa6 Mon Sep 17 00:00:00 2001 From: Skillnoob <78843978+Skillnoob@users.noreply.github.com> Date: Sun, 15 Sep 2024 16:59:30 +0200 Subject: [PATCH] Fix `mps.empty_cache()` for macOS without MPS (#16280) Co-authored-by: Glenn Jocher --- ultralytics/engine/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 03965a72..1b104681 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -454,7 +454,7 @@ class BaseTrainer: self.stop |= epoch >= self.epochs # stop if exceeded epochs self.run_callbacks("on_fit_epoch_end") gc.collect() - if MACOS: + if MACOS and self.device.type == "mps": torch.mps.empty_cache() # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memoy else: torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors @@ -479,7 +479,7 @@ class BaseTrainer: self.plot_metrics() self.run_callbacks("on_train_end") gc.collect() - if MACOS: + if MACOS and self.device.type == "mps": torch.mps.empty_cache() else: torch.cuda.empty_cache()