YOLOv5 updates (#90)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2022-12-25 14:33:18 +01:00 committed by GitHub
parent ebd3cfb2fd
commit 98815d560f
27 changed files with 281 additions and 161 deletions

@@ -60,7 +60,8 @@ class BaseTrainer:
 
         # device
         self.device = utils.torch_utils.select_device(self.args.device, self.batch_size)
-        self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
+        self.amp = self.device.type != 'cpu'
+        self.scaler = amp.GradScaler(enabled=self.amp)
 
         # Model and Dataloaders.
         self.model = self.args.model
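
Note: factoring the device check into a single self.amp flag lets one value drive both the GradScaler here and the autocast context added later in this commit. A minimal standalone sketch of the pattern, using illustrative names rather than the trainer's own:

    import torch
    from torch.cuda import amp

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type != 'cpu'            # mixed precision only off CPU
    scaler = amp.GradScaler(enabled=use_amp)  # a disabled GradScaler is a no-op passthrough
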
@@ -175,6 +176,10 @@ class BaseTrainer:
         nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
         last_opt_step = -1
         self.trigger_callbacks("on_train_start")
+        self.log(f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
+                 f'Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n'
+                 f"Logging results to {colorstr('bold', self.save_dir)}\n"
+                 f"Starting training for {self.epochs} epochs...")
         for epoch in range(self.start_epoch, self.epochs):
             self.epoch = epoch
             self.trigger_callbacks("on_train_epoch_start")
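
Note: the (world_size or 1) factor falls back to 1 when world_size is unset or 0 in single-device runs, so the logged dataloader-worker count stays correct outside DDP.
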
@@ -189,8 +194,6 @@ class BaseTrainer:
             self.optimizer.zero_grad()
             for i, batch in pbar:
                 self.trigger_callbacks("on_train_batch_start")
-                # forward
-                batch = self.preprocess_batch(batch)
 
                 # warmup
                 ni = i + nb * epoch
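
Note: the removed forward lines are moved rather than deleted; the next hunk re-adds preprocess_batch and the model forward pass inside the new autocast context.
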
@@ -204,17 +207,20 @@ class BaseTrainer:
                         if 'momentum' in x:
                             x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
 
-                preds = self.model(batch["img"])
-                self.loss, self.loss_items = self.criterion(preds, batch)
-                if rank != -1:
-                    self.loss *= world_size
-                self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
-                    else self.loss_items
+                # Forward
+                with torch.cuda.amp.autocast(self.amp):
+                    batch = self.preprocess_batch(batch)
+                    preds = self.model(batch["img"])
+                    self.loss, self.loss_items = self.criterion(preds, batch)
+                    if rank != -1:
+                        self.loss *= world_size
+                    self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
+                        else self.loss_items
 
-                # backward
+                # Backward
                 self.scaler.scale(self.loss).backward()
 
-                # optimize
+                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                 if ni - last_opt_step >= self.accumulate:
                     self.optimizer_step()
                     last_opt_step = ni
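
Note: the restructured loop follows the AMP recipe in the linked PyTorch notes: compute the forward pass and loss under autocast, scale the loss before backward(), and step the optimizer only every accumulate-th iteration. A self-contained sketch of that structure, assuming a toy model and synthetic batches (none of these names are the trainer's):

    import torch
    import torch.nn.functional as F

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_amp = device == 'cuda'                  # mirrors self.amp above
    model = torch.nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    accumulate = 4                              # optimizer step every 4 batches

    for i in range(16):
        x = torch.randn(8, 10, device=device)
        y = torch.randint(0, 2, (8,), device=device)
        with torch.cuda.amp.autocast(use_amp):  # forward + loss in mixed precision
            loss = F.cross_entropy(model(x), y)
        scaler.scale(loss).backward()           # gradients accumulate across iterations
        if (i + 1) % accumulate == 0:
            scaler.step(optimizer)              # unscales grads; skips the step on inf/nan
            scaler.update()
            optimizer.zero_grad()

Moving preprocess_batch under the context also keeps the whole forward path inside one autocast region.
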
@@ -237,7 +243,7 @@ class BaseTrainer:
             self.scheduler.step()
             self.trigger_callbacks("on_train_epoch_end")
 
-            if rank in [-1, 0]:
+            if rank in {-1, 0}:
                 # validation
                 self.trigger_callbacks('on_val_start')
                 self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
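
Note: swapping the list for a set literal is behavior-preserving for these hashable values, and CPython folds a set literal used in an "in" test into a constant frozenset, so the container is not rebuilt on every check.
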
@@ -245,7 +251,7 @@ class BaseTrainer:
                 if not self.args.noval or final_epoch:
                     self.metrics, self.fitness = self.validate()
                 self.trigger_callbacks('on_val_end')
-                log_vals = self.label_loss_items(self.tloss) | self.metrics | lr
+                log_vals = {**self.label_loss_items(self.tloss), **self.metrics, **lr}
                 self.save_metrics(metrics=log_vals)
 
                 # save model
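
Note: this is a compatibility fix rather than a behavior change. The dict union operator (PEP 584) requires Python 3.9+, while {**a, **b} unpacking (PEP 448) works on Python 3.5+ with the same later-mapping-wins semantics. A quick illustration with throwaway dicts:

    a, b = {'x': 1}, {'x': 2, 'y': 3}
    merged = {**a, **b}  # {'x': 2, 'y': 3}; later mappings win, same result as a | b on 3.9+
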
@@ -259,12 +265,13 @@ class BaseTrainer:
             # TODO: termination condition
 
-        if rank in [-1, 0]:
+        if rank in {-1, 0}:
             # do the last evaluation with best.pt
-            self.log(f'\n{epoch - self.start_epoch + 1} epochs completed in '
-                     f'{(time.time() - self.train_time_start) / 3600:.3f} hours.')
             self.final_eval()
             if self.args.plots:
                 self.plot_metrics()
+
+            self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
+            self.log(f"Results saved to {colorstr('bold', self.save_dir)}")
             self.trigger_callbacks('on_train_end')
 
         dist.destroy_process_group() if world_size > 1 else None
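
Note: because the summary self.log calls now run after final_eval() and plot_metrics(), the reported wall-clock hours include final evaluation and plotting time, which the removed pre-evaluation message did not.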