Add EMA and model checkpointing (#49)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Ayush Chaurasia authored on 2022-11-19 23:37:26 +05:30; committed by GitHub
parent 27d6545117
commit 4291b9c31c
6 changed files with 55 additions and 21 deletions

@@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural network
 import os
 import time
 from collections import defaultdict
+from copy import deepcopy
 from datetime import datetime
 from pathlib import Path
 from typing import Dict, Union
@@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
 from ultralytics.yolo.utils.checks import print_args
 from ultralytics.yolo.utils.files import increment_path, save_yaml
 from ultralytics.yolo.utils.modeling import get_model
+from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel

 DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
@@ -63,6 +65,7 @@ class BaseTrainer:
         self.trainset, self.testset = self.get_dataset(self.data)
         if self.args.model:
             self.model = self.get_model(self.args.model)
+        self.ema = None

         # epoch level metrics
         self.metrics = {}  # handle metrics returned by validator
@@ -144,6 +147,7 @@ class BaseTrainer:
             self.validator = self.get_validator()
             print("created testloader :", rank)
             self.console.info(self.progress_string())
+        self.ema = ModelEMA(self.model)

     def _do_train(self, rank=-1, world_size=1):
         if world_size > 1:
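
For context, the ModelEMA imported above is a YOLOv5-style exponential-moving-average helper: it keeps a slowly updated copy of the model weights, and that averaged copy is what gets validated and checkpointed. A minimal sketch of the idea, assuming the usual decay=0.9999 and tau=2000 defaults (the real class lives in ultralytics.yolo.utils.torch_utils and may differ in detail):

    import math
    from copy import deepcopy

    def de_parallel(model):
        # Unwrap DataParallel/DistributedDataParallel to the bare module
        return model.module if hasattr(model, 'module') else model

    class ModelEMA:
        # Maintains an exponential moving average of model weights
        def __init__(self, model, decay=0.9999, tau=2000, updates=0):
            self.ema = deepcopy(de_parallel(model)).eval()  # averaged copy
            self.updates = updates  # number of EMA updates so far
            # decay ramps from 0 toward `decay` so early updates track the model closely
            self.decay = lambda x: decay * (1 - math.exp(-x / tau))
            for p in self.ema.parameters():
                p.requires_grad_(False)

        def update(self, model):
            # Called once per optimizer step (see the last hunk below)
            self.updates += 1
            d = self.decay(self.updates)
            msd = de_parallel(model).state_dict()
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()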
@@ -196,6 +200,7 @@ class BaseTrainer:
             if rank in [-1, 0]:
                 # validation
                 # callback: on_val_start()
+                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                 self.validate()
                 # callback: on_val_end()
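
Before validating, update_attr syncs non-parameter attributes (class names, strides, args, and so on) from the live model onto the EMA copy, so the averaged model is self-describing at validation and checkpoint time. A sketch of that attribute copy, assuming a YOLOv5-style copy_attr helper; illustrative, not lifted from this commit:

    def copy_attr(a, b, include=(), exclude=()):
        # Copy attributes from b to a, restricted to `include` if given,
        # always skipping private names and anything in `exclude`
        for k, v in b.__dict__.items():
            if (include and k not in include) or k.startswith('_') or k in exclude:
                continue
            setattr(a, k, v)

    class ModelEMA:
        # __init__/update as in the earlier sketch
        def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
            copy_attr(self.ema, model, include, exclude)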
@@ -220,10 +225,10 @@ class BaseTrainer:
         ckpt = {
             'epoch': self.epoch,
             'best_fitness': self.best_fitness,
-            'model': None,  # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
-            'ema': None,  # deepcopy(ema.ema).half(),
-            'updates': None,  # ema.updates,
-            'optimizer': None,  # optimizer.state_dict(),
+            'model': deepcopy(de_parallel(self.model)).half(),
+            'ema': deepcopy(self.ema.ema).half(),
+            'updates': self.ema.updates,
+            'optimizer': self.optimizer.state_dict(),
             'train_args': self.args,
             'date': datetime.now().isoformat()}
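
With the placeholders replaced by live values, the checkpoint now carries both the raw and the EMA weights in half precision plus the optimizer state, which is what makes training resumable. A hypothetical save/resume sketch: torch.save and torch.load are standard PyTorch, but the 'last.pt' path and the surrounding variables (model, ema, optimizer) are illustrative assumptions, not from this commit:

    import torch

    torch.save(ckpt, 'last.pt')  # hypothetical checkpoint path

    # Later, to resume: restore weights, EMA state, and optimizer state
    ckpt = torch.load('last.pt', map_location='cpu')
    model.load_state_dict(ckpt['model'].float().state_dict())  # ckpt stores a half() module
    if ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
        ema.updates = ckpt['updates']
    optimizer.load_state_dict(ckpt['optimizer'])
    start_epoch = ckpt['epoch'] + 1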
@@ -266,6 +271,8 @@ class BaseTrainer:
             self.scaler.step(self.optimizer)
             self.scaler.update()
             self.optimizer.zero_grad()
+            if self.ema:
+                self.ema.update(self.model)

     def preprocess_batch(self, batch):
         """