Favor EMA over model in train checkpoints (#9433)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-31 06:14:22 +02:00 committed by GitHub
parent 479afce4a8
commit 7ea2007326
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -42,7 +42,6 @@ from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import ( from ultralytics.utils.torch_utils import (
EarlyStopping, EarlyStopping,
ModelEMA, ModelEMA,
de_parallel,
init_seeds, init_seeds,
one_cycle, one_cycle,
select_device, select_device,
@ -486,7 +485,7 @@ class BaseTrainer:
{ {
"epoch": self.epoch, "epoch": self.epoch,
"best_fitness": self.best_fitness, "best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(), "model": None, # resume and final checkpoints derive from EMA
"ema": deepcopy(self.ema.ema).half(), "ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates, "updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(), "optimizer": self.optimizer.state_dict(),
@ -527,7 +526,7 @@ class BaseTrainer:
ckpt = None ckpt = None
if str(model).endswith(".pt"): if str(model).endswith(".pt"):
weights, ckpt = attempt_load_one_weight(model) weights, ckpt = attempt_load_one_weight(model)
cfg = ckpt["model"].yaml cfg = weights.yaml
else: else:
cfg = model cfg = model
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights) self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)