MPS unified memory cache empty (#16078)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent ccd2937aa1
commit 4d5afa7e0d
1 changed file with 10 additions and 2 deletions
@@ -28,6 +28,7 @@ from ultralytics.utils import (
     DEFAULT_CFG,
     LOCAL_RANK,
     LOGGER,
+    MACOS,
     RANK,
     TQDM,
     __version__,
@@ -453,6 +454,9 @@ class BaseTrainer:
                 self.stop |= epoch >= self.epochs  # stop if exceeded epochs
                 self.run_callbacks("on_fit_epoch_end")
                 gc.collect()
-                torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+                if MACOS:
+                    torch.mps.empty_cache()  # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memory
+                else:
+                    torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
 
             # Early Stopping
@@ -475,7 +479,11 @@ class BaseTrainer:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
         gc.collect()
-        torch.cuda.empty_cache()
+        if MACOS:
+            torch.mps.empty_cache()
+        else:
+            torch.cuda.empty_cache()
+
         self.run_callbacks("teardown")
 
     def read_results_csv(self):
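Below is a minimal, self-contained sketch of the same epoch-end cache-clearing pattern this diff adds to BaseTrainer, for readers who want to reuse it outside Ultralytics. It assumes only PyTorch is installed; the helper name clear_device_cache is hypothetical (not an Ultralytics API), and the explicit platform/backend checks stand in for the MACOS constant that the diff imports from ultralytics.utils.

import gc
import platform

import torch


def clear_device_cache() -> None:
    # Hypothetical helper, not part of Ultralytics.
    gc.collect()  # drop unreachable Python objects so the tensors they hold can actually be freed
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        torch.mps.empty_cache()  # release unoccupied unified (shared CPU/GPU) memory held by the MPS caching allocator
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver, reducing out-of-memory risk


# Typical use: call once per epoch, e.g.
#   for epoch in range(epochs):
#       train_one_epoch(model, loader)
#       clear_device_cache()

Note that empty_cache() only releases memory held by PyTorch's caching allocator; tensors still referenced from Python stay allocated, which is why gc.collect() runs first.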