YOLOv5 updates (#90)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ebd3cfb2fd
commit
98815d560f
27 changed files with 281 additions and 161 deletions
|
|
@ -60,7 +60,8 @@ class BaseTrainer:
|
|||
|
||||
# device
|
||||
self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
self.amp = self.device.type != 'cpu'
|
||||
self.scaler = amp.GradScaler(enabled=self.amp)
|
||||
|
||||
# Model and Dataloaders.
|
||||
self.model = self.args.model
|
||||
|
|
@ -175,6 +176,10 @@ class BaseTrainer:
|
|||
nw = max(round(self.args.warmup_epochs * nb), 100) # number of warmup iterations
|
||||
last_opt_step = -1
|
||||
self.trigger_callbacks("on_train_start")
|
||||
self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
||||
f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f"Starting training for {self.epochs} epochs...")
|
||||
for epoch in range(self.start_epoch, self.epochs):
|
||||
self.epoch = epoch
|
||||
self.trigger_callbacks("on_train_epoch_start")
|
||||
|
|
@ -189,8 +194,6 @@ class BaseTrainer:
|
|||
self.optimizer.zero_grad()
|
||||
for i, batch in pbar:
|
||||
self.trigger_callbacks("on_train_batch_start")
|
||||
# forward
|
||||
batch = self.preprocess_batch(batch)
|
||||
|
||||
# warmup
|
||||
ni = i + nb * epoch
|
||||
|
|
@ -204,17 +207,20 @@ class BaseTrainer:
|
|||
if 'momentum' in x:
|
||||
x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||
|
||||
preds = self.model(batch["img"])
|
||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||
if rank != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
# Forward
|
||||
with torch.cuda.amp.autocast(self.amp):
|
||||
batch = self.preprocess_batch(batch)
|
||||
preds = self.model(batch["img"])
|
||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||
if rank != -1:
|
||||
self.loss *= world_size
|
||||
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
|
||||
else self.loss_items
|
||||
|
||||
# backward
|
||||
# Backward
|
||||
self.scaler.scale(self.loss).backward()
|
||||
|
||||
# optimize
|
||||
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
||||
if ni - last_opt_step >= self.accumulate:
|
||||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
|
@ -237,7 +243,7 @@ class BaseTrainer:
|
|||
self.scheduler.step()
|
||||
self.trigger_callbacks("on_train_epoch_end")
|
||||
|
||||
if rank in [-1, 0]:
|
||||
if rank in {-1, 0}:
|
||||
# validation
|
||||
self.trigger_callbacks('on_val_start')
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
|
|
@ -245,7 +251,7 @@ class BaseTrainer:
|
|||
if not self.args.noval or final_epoch:
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.trigger_callbacks('on_val_end')
|
||||
log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
|
||||
log_vals = {**self.label_loss_items(self.tloss), **self.metrics, **lr}
|
||||
self.save_metrics(metrics=log_vals)
|
||||
|
||||
# save model
|
||||
|
|
@ -259,12 +265,13 @@ class BaseTrainer:
|
|||
|
||||
# TODO: termination condition
|
||||
|
||||
if rank in [-1, 0]:
|
||||
if rank in {-1, 0}:
|
||||
# do the last evaluation with best.pt
|
||||
self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
|
||||
f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
|
||||
self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
self.trigger_callbacks('on_train_end')
|
||||
dist.destroy_process_group() if world_size > 1 else None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue