ultralytics 8.0.174 Tuner plots and improvements (#4799)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dfe6dfb1d2
commit 16ce193d6e
12 changed files with 248 additions and 88 deletions
ultralytics/engine/trainer.py

@@ -423,7 +423,10 @@ class BaseTrainer:
         self.run_callbacks('teardown')
 
     def save_model(self):
-        """Save model checkpoints based on various conditions."""
+        """Save model training checkpoints with additional metadata."""
+        import pandas as pd  # scope for faster startup
+        metrics = {**self.metrics, **{'fitness': self.fitness}}
+        results = {k.strip(): v for k, v in pd.read_csv(self.save_dir / 'results.csv').to_dict(orient='list').items()}
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
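Aside on the new `results` comprehension: the column headers written to results.csv may carry padding whitespace, hence the `k.strip()`, and `to_dict(orient='list')` yields one list per column. A minimal sketch of what it produces, using a hand-built frame in place of the real CSV:

import pandas as pd

# Stand-in for pd.read_csv(self.save_dir / 'results.csv'); headers in the
# real file may be space-padded, which is why the keys are stripped.
df = pd.DataFrame({'       epoch': [1, 2], '  train/box_loss': [1.40, 1.25]})
results = {k.strip(): v for k, v in df.to_dict(orient='list').items()}
print(results)  # {'epoch': [1, 2], 'train/box_loss': [1.4, 1.25]}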
@@ -432,22 +435,17 @@ class BaseTrainer:
             'updates': self.ema.updates,
             'optimizer': self.optimizer.state_dict(),
             'train_args': vars(self.args),  # save as dict
+            'train_metrics': metrics,
+            'train_results': results,
             'date': datetime.now().isoformat(),
             'version': __version__}
 
-        # Use dill (if exists) to serialize the lambda functions where pickle does not do this
-        try:
-            import dill as pickle
-        except ImportError:
-            import pickle
-
-        # Save last, best and delete
-        torch.save(ckpt, self.last, pickle_module=pickle)
+        # Save last and best
+        torch.save(ckpt, self.last)
         if self.best_fitness == self.fitness:
-            torch.save(ckpt, self.best, pickle_module=pickle)
-        if (self.epoch > 0) and (self.save_period > 0) and (self.epoch % self.save_period == 0):
-            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt', pickle_module=pickle)
-        del ckpt
+            torch.save(ckpt, self.best)
+        if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
+            torch.save(ckpt, self.wdir / f'epoch{self.epoch}.pt')
 
     @staticmethod
     def get_dataset(data):
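With this change a saved checkpoint carries its training metrics and the parsed results.csv contents alongside the weights. A minimal sketch of reading them back, assuming a hypothetical run directory and a torch version whose default torch.load settings permit unpickling the full checkpoint:

import torch

# Hypothetical path; substitute your own run's weights directory. Loading
# requires ultralytics installed so the pickled model classes resolve.
ckpt = torch.load('runs/detect/train/weights/last.pt', map_location='cpu')

print(ckpt['train_metrics'])           # validation metrics plus 'fitness'
print(ckpt['train_results']['epoch'])  # per-epoch columns from results.csv
print(ckpt['epoch'], ckpt['date'], ckpt['version'])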
@@ -654,6 +652,9 @@ class BaseTrainer:
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         if name == 'auto':
+            LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
+                        f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
+                        f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
             nc = getattr(model, 'nc', 10)  # number of classes
             lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
             name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
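The 'auto' branch now logs that it is overriding lr0 and momentum, then picks the optimizer from the schedule length: SGD with lr0=0.01 beyond 10k iterations, otherwise AdamW with an lr0 fitted to the class count. A standalone sketch of that heuristic (auto_optimizer is an illustrative helper, not a library function):

def auto_optimizer(nc: int, iterations: int):
    """Mirror the 'optimizer=auto' selection shown in the hunk above."""
    lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
    return ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)

print(auto_optimizer(nc=80, iterations=5000))    # ('AdamW', 0.000119, 0.9)
print(auto_optimizer(nc=80, iterations=200000))  # ('SGD', 0.01, 0.9)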