ultralytics 8.0.206 engine Trainer updates (#6111)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jamjamjon <51357717+jamjamjon@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-11-04 02:57:35 +01:00 committed by GitHub
parent 25bd3b9834
commit f2f5ed2c5e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 42 additions and 34 deletions

View file

@ -19,8 +19,6 @@ import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
@ -28,7 +26,7 @@ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
@ -43,7 +41,6 @@ class BaseTrainer:
Attributes:
args (SimpleNamespace): Configuration for the trainer.
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
@ -62,6 +59,7 @@ class BaseTrainer:
trainset (torch.utils.data.Dataset): Training dataset.
testset (torch.utils.data.Dataset): Testing dataset.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
resume (bool): Resume training from a checkpoint.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
@ -84,7 +82,6 @@ class BaseTrainer:
self.check_resume(overrides)
self.device = select_device(self.args.device, self.args.batch)
self.validator = None
self.model = None
self.metrics = None
self.plots = {}
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
@ -111,7 +108,7 @@ class BaseTrainer:
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
# Model and Dataset
self.model = self.args.model
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
try:
if self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
@ -124,6 +121,7 @@ class BaseTrainer:
self.trainset, self.testset = self.get_dataset(self.data)
self.ema = None
self.resume = False
# Optimization utils init
self.lf = None
@ -236,9 +234,9 @@ class BaseTrainer:
if RANK > -1 and world_size > 1: # DDP
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
self.amp = bool(self.amp) # as boolean
self.scaler = amp.GradScaler(enabled=self.amp)
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
if world_size > 1:
self.model = DDP(self.model, device_ids=[RANK])
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
@ -311,11 +309,7 @@ class BaseTrainer:
pbar = enumerate(self.train_loader)
# Update dataloader attributes (optional)
if epoch == (self.epochs - self.args.close_mosaic):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
self._close_dataloader_mosaic()
self.train_loader.reset()
if RANK in (-1, 0):
@ -395,7 +389,7 @@ class BaseTrainer:
self.epoch_time = tnow - self.epoch_time_start
self.epoch_time_start = tnow
self.run_callbacks('on_fit_epoch_end')
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
# Early Stopping
if RANK != -1: # if DDP training
@ -613,11 +607,15 @@ class BaseTrainer:
self.best_fitness = best_fitness
self.start_epoch = start_epoch
if start_epoch > (self.epochs - self.args.close_mosaic):
self._close_dataloader_mosaic()
def _close_dataloader_mosaic(self):
"""Update dataloaders to stop using mosaic augmentation."""
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
LOGGER.info('Closing dataloader mosaic')
if hasattr(self.train_loader.dataset, 'mosaic'):
self.train_loader.dataset.mosaic = False
if hasattr(self.train_loader.dataset, 'close_mosaic'):
self.train_loader.dataset.close_mosaic(hyp=self.args)
self.train_loader.dataset.close_mosaic(hyp=self.args)
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
"""