ultralytics 8.2.29 new fractional AutoBatch feature (#13446)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-09 00:28:11 +02:00 committed by GitHub
parent 2fe0946376
commit 6a234f3639
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 92 additions and 49 deletions

View file

@ -10,9 +10,9 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import profile
def check_train_batch_size(model, imgsz=640, amp=True):
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
"""
Check YOLO training batch size using the autobatch() function.
Compute optimal YOLO training batch size using the autobatch() function.
Args:
model (torch.nn.Module): YOLO model to check batch size for.
@ -24,7 +24,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
"""
with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
@ -43,7 +43,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
# Check device
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
device = next(model.parameters()).device # get model device
if device.type == "cpu":
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")

View file

@ -146,11 +146,17 @@ def select_device(device="", batch=0, newline=False, verbose=True):
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
if n > 1: # multi-GPU
if batch < 1:
raise ValueError(
"AutoBatch with batch<1 not supported for Multi-GPU training, "
"please specify a valid batch size, i.e. batch=16."
)
if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)