ultralytics 8.2.94 Apple MPS train memory display (#16272)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Quet Almahdi Morris 2024-09-15 12:34:31 -05:00 committed by GitHub
parent 87296e9e75
commit fa6362a6f5
4 changed files with 38 additions and 26 deletions

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.2.93"
+__version__ = "8.2.94"
 import os

ultralytics/engine/trainer.py

@@ -28,7 +28,6 @@ from ultralytics.utils import (
     DEFAULT_CFG,
     LOCAL_RANK,
     LOGGER,
-    MACOS,
     RANK,
     TQDM,
     __version__,
@@ -409,13 +408,17 @@ class BaseTrainer:
                             break

                 # Log
-                mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G"  # (GB)
-                loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
-                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                 if RANK in {-1, 0}:
+                    loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
                     pbar.set_description(
-                        ("%11s" * 2 + "%11.4g" * (2 + loss_len))
-                        % (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
+                        ("%11s" * 2 + "%11.4g" * (2 + loss_length))
+                        % (
+                            f"{epoch + 1}/{self.epochs}",
+                            f"{self._get_memory():.3g}G",  # (GB) GPU memory util
+                            *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
+                            batch["cls"].shape[0],  # batch size, i.e. 8
+                            batch["img"].shape[-1],  # imgsz, i.e. 640
+                        )
                     )
                 self.run_callbacks("on_batch_end")
                 if self.args.plots and ni in self.plot_idx:
@@ -453,11 +456,7 @@
             self.scheduler.last_epoch = self.epoch  # do not move
             self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
-            gc.collect()
-            if MACOS and self.device.type == "mps":
-                torch.mps.empty_cache()  # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memory
-            else:
-                torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+            self._clear_memory()

             # Early Stopping
             if RANK != -1:  # if DDP training
@@ -478,14 +477,29 @@
             if self.args.plots:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
-        gc.collect()
-        if MACOS and self.device.type == "mps":
-            torch.mps.empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        self._clear_memory()
         self.run_callbacks("teardown")

+    def _get_memory(self):
+        """Get accelerator memory utilization in GB."""
+        if self.device.type == "mps":
+            memory = torch.mps.driver_allocated_memory()
+        elif self.device.type == "cpu":
+            memory = 0
+        else:
+            memory = torch.cuda.memory_reserved()
+        return memory / 1e9
+
+    def _clear_memory(self):
+        """Clear accelerator memory on different platforms."""
+        gc.collect()
+        if self.device.type == "mps":
+            torch.mps.empty_cache()
+        elif self.device.type == "cpu":
+            return
+        else:
+            torch.cuda.empty_cache()
+
     def read_results_csv(self):
         """Read results.csv into a dict using pandas."""
         import pandas as pd  # scope for faster 'import ultralytics'
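
For reference, the change above boils down to a device-aware memory query (torch.mps.driver_allocated_memory() on Apple MPS, torch.cuda.memory_reserved() on CUDA, 0 on CPU) plus an explicit cache flush after garbage collection, which is why the progress bar can now show a nonzero memory column on Apple silicon instead of the previous CUDA-only 0G. A minimal standalone sketch of the same idea follows; the helper names get_accelerator_memory and clear_accelerator_memory are illustrative only and not part of the Ultralytics API.

import gc

import torch


def get_accelerator_memory(device: torch.device) -> float:
    """Return accelerator memory utilization in GB (0 on CPU), mirroring _get_memory() above."""
    if device.type == "mps":
        memory = torch.mps.driver_allocated_memory()  # unified memory held by the Metal driver, in bytes
    elif device.type == "cpu":
        memory = 0
    else:
        memory = torch.cuda.memory_reserved()  # bytes reserved by the CUDA caching allocator
    return memory / 1e9


def clear_accelerator_memory(device: torch.device) -> None:
    """Run garbage collection, then release cached accelerator memory, mirroring _clear_memory() above."""
    gc.collect()
    if device.type == "mps":
        torch.mps.empty_cache()
    elif device.type == "cuda":
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Pick the best available accelerator, then report memory before and after freeing a large tensor
    device = torch.device(
        "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
    )
    x = torch.zeros(256, 1024, 1024, device=device)  # ~1 GB of float32
    print(f"allocated: {get_accelerator_memory(device):.3g}G")
    del x
    clear_accelerator_memory(device)
    print(f"after clear: {get_accelerator_memory(device):.3g}G")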