ultralytics 8.0.174 Tuner plots and improvements (#4799)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-09-10 03:27:23 +02:00 committed by GitHub
parent dfe6dfb1d2
commit 16ce193d6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 248 additions and 88 deletions

View file

@ -423,7 +423,10 @@ class BaseTrainer:
self.run_callbacks('teardown')
def save_model(self):
"""Save model checkpoints based on various conditions."""
"""Save model training checkpoints with additional metadata."""
import pandas as pd # scope for faster startup
metrics = {**self.metrics, **{'fitness': self.fitness}}
results = {k.strip(): v for k, v in pd.read_csv(self.save_dir / 'results.csv').to_dict(orient='list').items()}
ckpt = {
'epoch': self.epoch,
'best_fitness': self.best_fitness,
@ -432,22 +435,17 @@ class BaseTrainer:
'updates': self.ema.updates,
'optimizer': self.optimizer.state_dict(),
'train_args': vars(self.args), # save as dict
'train_metrics': metrics,
'train_results': results,
'date': datetime.now().isoformat(),
'version': __version__}
# Use dill (if exists) to serialize the lambda functions where pickle does not do this
try:
import dill as pickle
except ImportError:
import pickle
# Save last, best and delete
torch.save(ckpt, self.last, pickle_module=pickle)
# Save last and best
torch.save(ckpt, self.last)
if self.best_fitness == self.fitness:
torch.save(ckpt, self.best, pickle_module=pickle)
if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
del ckpt
torch.save(ckpt, self.best)
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
@staticmethod
def get_dataset(data):
@ -654,6 +652,9 @@ class BaseTrainer:
g = [], [], [] # optimizer parameter groups
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
if name == 'auto':
LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
nc = getattr(model, 'nc', 10) # number of classes
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)