ultralytics 8.0.228 add training time argument (#7054)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b206b68ac6
commit
6cbe736bfd
11 changed files with 58 additions and 34 deletions
|
|
@ -189,6 +189,14 @@ class BaseTrainer:
|
|||
else:
|
||||
self._do_train(world_size)
|
||||
|
||||
def _setup_scheduler(self):
|
||||
"""Initialize training learning rate scheduler."""
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
|
|
@ -269,11 +277,7 @@ class BaseTrainer:
|
|||
decay=weight_decay,
|
||||
iterations=iterations)
|
||||
# Scheduler
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: (1 - x / self.epochs) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
self._setup_scheduler()
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
|
|
@ -285,17 +289,18 @@ class BaseTrainer:
|
|||
self._setup_ddp(world_size)
|
||||
self._setup_train(world_size)
|
||||
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
||||
last_opt_step = -1
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
self.run_callbacks('on_train_start')
|
||||
LOGGER.info(f'Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n'
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f'Starting training for {self.epochs} epochs...')
|
||||
f'Starting training for '
|
||||
f'{self.args.time} hours...' if self.args.time else f'{self.epochs} epochs...')
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||
|
|
@ -323,7 +328,7 @@ class BaseTrainer:
|
|||
ni = i + nb * epoch
|
||||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
self.accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())
|
||||
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x['lr'] = np.interp(
|
||||
|
|
@ -348,6 +353,16 @@ class BaseTrainer:
|
|||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
||||
# Timed stopping
|
||||
if self.args.time:
|
||||
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop: # training time exceeded
|
||||
break
|
||||
|
||||
# Log
|
||||
mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
|
||||
|
|
@ -363,31 +378,37 @@ class BaseTrainer:
|
|||
self.run_callbacks('on_train_batch_end')
|
||||
|
||||
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
|
||||
if RANK in (-1, 0):
|
||||
final_epoch = epoch + 1 == self.epochs
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
|
||||
# Validation
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
final_epoch = (epoch + 1 == self.epochs) or self.stopper.possible_stop
|
||||
|
||||
if self.args.val or final_epoch:
|
||||
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||
self.stop = self.stopper(epoch + 1, self.fitness)
|
||||
self.stop |= self.stopper(epoch + 1, self.fitness)
|
||||
if self.args.time:
|
||||
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
|
||||
# Save model
|
||||
if self.args.save or (epoch + 1 == self.epochs):
|
||||
if self.args.save or final_epoch:
|
||||
self.save_model()
|
||||
self.run_callbacks('on_model_save')
|
||||
|
||||
tnow = time.time()
|
||||
self.epoch_time = tnow - self.epoch_time_start
|
||||
self.epoch_time_start = tnow
|
||||
# Scheduler
|
||||
t = time.time()
|
||||
self.epoch_time = t - self.epoch_time_start
|
||||
self.epoch_time_start = t
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
if self.args.time:
|
||||
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||
self._setup_scheduler()
|
||||
self.scheduler.last_epoch = self.epoch # do not move
|
||||
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_fit_epoch_end')
|
||||
torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
|
||||
|
||||
|
|
@ -395,8 +416,7 @@ class BaseTrainer:
|
|||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
if RANK != 0:
|
||||
self.stop = broadcast_list[0]
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue