Add EMA and model checkpointing (#49)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
27d6545117
commit
4291b9c31c
6 changed files with 55 additions and 21 deletions
|
|
@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo
|
|||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
|
@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT
|
|||
from ultralytics.yolo.utils.checks import print_args
|
||||
from ultralytics.yolo.utils.files import increment_path, save_yaml
|
||||
from ultralytics.yolo.utils.modeling import get_model
|
||||
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel
|
||||
|
||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||
|
||||
|
|
@ -63,6 +65,7 @@ class BaseTrainer:
|
|||
self.trainset, self.testset = self.get_dataset(self.data)
|
||||
if self.args.model:
|
||||
self.model = self.get_model(self.args.model)
|
||||
self.ema = None
|
||||
|
||||
# epoch level metrics
|
||||
self.metrics = {} # handle metrics returned by validator
|
||||
|
|
@ -144,6 +147,7 @@ class BaseTrainer:
|
|||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
self.console.info(self.progress_string())
|
||||
self.ema = ModelEMA(self.model)
|
||||
|
||||
def _do_train(self, rank=-1, world_size=1):
|
||||
if world_size > 1:
|
||||
|
|
@ -196,6 +200,7 @@ class BaseTrainer:
|
|||
if rank in [-1, 0]:
|
||||
# validation
|
||||
# callback: on_val_start()
|
||||
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
|
||||
self.validate()
|
||||
# callback: on_val_end()
|
||||
|
||||
|
|
@ -220,10 +225,10 @@ class BaseTrainer:
|
|||
ckpt = {
|
||||
'epoch': self.epoch,
|
||||
'best_fitness': self.best_fitness,
|
||||
'model': None, # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
|
||||
'ema': None, # deepcopy(ema.ema).half(),
|
||||
'updates': None, # ema.updates,
|
||||
'optimizer': None, # optimizer.state_dict(),
|
||||
'model': deepcopy(de_parallel(self.model)).half(),
|
||||
'ema': deepcopy(self.ema.ema).half(),
|
||||
'updates': self.ema.updates,
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'train_args': self.args,
|
||||
'date': datetime.now().isoformat()}
|
||||
|
||||
|
|
@ -266,6 +271,8 @@ class BaseTrainer:
|
|||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
if self.ema:
|
||||
self.ema.update(self.model)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue