ultralytics 8.3.31 add max_num_obj factor for AutoBatch (#17514)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-11-14 06:51:24 +08:00 committed by GitHub
parent e100484422
commit 4453ddab93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 38 additions and 12 deletions

View file

@ -279,12 +279,7 @@ class BaseTrainer:
# Batch size
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
self.args.batch = self.batch_size = check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
amp=self.amp,
batch=self.batch_size,
)
self.args.batch = self.batch_size = self.auto_batch()
# Dataloaders
batch_size = self.batch_size // max(world_size, 1)
@ -478,6 +473,16 @@ class BaseTrainer:
self._clear_memory()
self.run_callbacks("teardown")
def auto_batch(self, max_num_obj=0):
"""Get batch size by calculating memory occupation of model."""
return check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
amp=self.amp,
batch=self.batch_size,
max_num_obj=max_num_obj,
) # returns batch size
def _get_memory(self):
"""Get accelerator memory utilization in GB."""
if self.device.type == "mps":