ultralytics 8.3.31 add max_num_obj factor for AutoBatch (#17514)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e100484422
commit
4453ddab93
6 changed files with 38 additions and 12 deletions
|
|
@ -279,12 +279,7 @@ class BaseTrainer:
|
|||
|
||||
# Batch size
|
||||
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = check_train_batch_size(
|
||||
model=self.model,
|
||||
imgsz=self.args.imgsz,
|
||||
amp=self.amp,
|
||||
batch=self.batch_size,
|
||||
)
|
||||
self.args.batch = self.batch_size = self.auto_batch()
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
|
|
@ -478,6 +473,16 @@ class BaseTrainer:
|
|||
self._clear_memory()
|
||||
self.run_callbacks("teardown")
|
||||
|
||||
def auto_batch(self, max_num_obj=0):
|
||||
"""Get batch size by calculating memory occupation of model."""
|
||||
return check_train_batch_size(
|
||||
model=self.model,
|
||||
imgsz=self.args.imgsz,
|
||||
amp=self.amp,
|
||||
batch=self.batch_size,
|
||||
max_num_obj=max_num_obj,
|
||||
) # returns batch size
|
||||
|
||||
def _get_memory(self):
|
||||
"""Get accelerator memory utilization in GB."""
|
||||
if self.device.type == "mps":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue