Fix mps.empty_cache() for macOS without MPS (#16280)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Skillnoob authored on 2024-09-15 16:59:30 +02:00, committed by GitHub
parent 887b46216c
commit e0cbbf409d


@@ -454,7 +454,7 @@ class BaseTrainer:
             self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
             gc.collect()
-            if MACOS:
+            if MACOS and self.device.type == "mps":
                 torch.mps.empty_cache()  # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memory
             else:
                 torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
@@ -479,7 +479,7 @@ class BaseTrainer:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
         gc.collect()
-        if MACOS:
+        if MACOS and self.device.type == "mps":
            torch.mps.empty_cache()
        else:
            torch.cuda.empty_cache()
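
The guard matters because the MACOS constant is true on any Mac, including machines where the MPS backend is unavailable; there, torch.mps.empty_cache() fails rather than no-opping, which is what this commit fixes. Below is a minimal standalone sketch of the same pattern, branching on the active device type instead of the platform alone, assuming a torch.device is already in hand (the empty_device_cache helper name is illustrative, not part of the commit):

    import torch

    def empty_device_cache(device: torch.device) -> None:
        """Release cached allocator memory for the device's backend, if it has one."""
        if device.type == "mps":
            # Branch on the device type, not the platform: a Mac without MPS
            # support would fail on torch.mps.empty_cache() rather than no-op.
            torch.mps.empty_cache()
        elif device.type == "cuda":
            torch.cuda.empty_cache()  # clear cached CUDA allocator memory

    # Example usage at the end of a training epoch:
    # empty_device_cache(trainer.device)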