ultralytics 8.0.206 engine Trainer updates (#6111)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: jamjamjon <51357717+jamjamjon@users.noreply.github.com>
This commit is contained in:
parent
25bd3b9834
commit
f2f5ed2c5e
7 changed files with 42 additions and 34 deletions
|
|
@ -19,8 +19,6 @@ import numpy as np
|
|||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch import nn, optim
|
||||
from torch.cuda import amp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
|
|
@ -28,7 +26,7 @@ from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
|||
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM, __version__, callbacks, clean_url, colorstr, emojis,
|
||||
yaml_save)
|
||||
from ultralytics.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.utils.files import get_latest_run
|
||||
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
|
||||
|
|
@ -43,7 +41,6 @@ class BaseTrainer:
|
|||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
check_resume (method): Method to check if training should be resumed from a saved checkpoint.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
|
|
@ -62,6 +59,7 @@ class BaseTrainer:
|
|||
trainset (torch.utils.data.Dataset): Training dataset.
|
||||
testset (torch.utils.data.Dataset): Testing dataset.
|
||||
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||
resume (bool): Resume training from a checkpoint.
|
||||
lf (nn.Module): Loss function.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||
best_fitness (float): The best fitness value achieved.
|
||||
|
|
@ -84,7 +82,6 @@ class BaseTrainer:
|
|||
self.check_resume(overrides)
|
||||
self.device = select_device(self.args.device, self.args.batch)
|
||||
self.validator = None
|
||||
self.model = None
|
||||
self.metrics = None
|
||||
self.plots = {}
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
|
@ -111,7 +108,7 @@ class BaseTrainer:
|
|||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
self.model = self.args.model
|
||||
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
try:
|
||||
if self.args.task == 'classify':
|
||||
self.data = check_cls_dataset(self.args.data)
|
||||
|
|
@ -124,6 +121,7 @@ class BaseTrainer:
|
|||
|
||||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
self.ema = None
|
||||
self.resume = False
|
||||
|
||||
# Optimization utils init
|
||||
self.lf = None
|
||||
|
|
@ -236,9 +234,9 @@ class BaseTrainer:
|
|||
if RANK > -1 and world_size > 1: # DDP
|
||||
dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None)
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||
if world_size > 1:
|
||||
self.model = DDP(self.model, device_ids=[RANK])
|
||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK])
|
||||
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
|
||||
|
|
@ -311,11 +309,7 @@ class BaseTrainer:
|
|||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
|
|
@ -395,7 +389,7 @@ class BaseTrainer:
|
|||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
torch.cuda.empty_cache() # clears GPU vRAM at end of epoch, can help with out of memory errors
|
||||
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
||||
|
||||
# Early Stopping
|
||||
if RANK != -1: # if DDP training
|
||||
|
|
@ -613,11 +607,15 @@ class BaseTrainer:
|
|||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
self._close_dataloader_mosaic()
|
||||
|
||||
def _close_dataloader_mosaic(self):
|
||||
"""Update dataloaders to stop using mosaic augmentation."""
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
LOGGER.info('Closing dataloader mosaic')
|
||||
if hasattr(self.train_loader.dataset, 'mosaic'):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, 'close_mosaic'):
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
self.train_loader.dataset.close_mosaic(hyp=self.args)
|
||||
|
||||
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue