diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 8f01562a..0016c970 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -7,6 +7,7 @@ import random
 import time
 from contextlib import contextmanager
 from copy import deepcopy
+from datetime import datetime
 from pathlib import Path
 from typing import Union
 
@@ -456,14 +457,17 @@ def init_seeds(seed=0, deterministic=False):
 
 
 class ModelEMA:
-    """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
-    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    """
+    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
+    average of everything in the model state_dict (parameters and buffers)
+
     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
     To disable EMA set the `enabled` attribute to `False`.
     """
 
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        """Create EMA."""
+        """Initialize EMA for 'model' with given arguments."""
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
@@ -506,15 +510,25 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         from pathlib import Path
         from ultralytics.utils.torch_utils import strip_optimizer
 
-        for f in Path('path/to/weights').rglob('*.pt'):
+        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
             strip_optimizer(f)
         ```
     """
-    x = torch.load(f, map_location=torch.device("cpu"))
-    if not isinstance(x, dict) or "model" not in x:
-        LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
+    try:
+        x = torch.load(f, map_location=torch.device("cpu"))
+        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
+        assert "model" in x, "'model' missing from checkpoint"
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
         return
 
+    updates = {
+        "date": datetime.now().isoformat(),
+        "version": __version__,
+        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+        "docs": "https://docs.ultralytics.com",
+    }
+
     # Update model
     if x.get("ema"):
         x["model"] = x["ema"]  # replace model with EMA
@@ -535,7 +549,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     #     x['model'].args = x['train_args']
 
     # Save
-    torch.save(x, s or f)
+    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 
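
Note on the `ModelEMA` hunk: the `decay` lambda in `__init__` ramps the effective decay from near zero toward its nominal value, which is what the "to help early epochs" comment refers to. A minimal standalone sketch of that ramp, using the defaults from the diff:

```python
import math

decay, tau = 0.9999, 2000  # defaults from ModelEMA.__init__ in the diff

# Effective decay d(x) = decay * (1 - exp(-x / tau)) starts near 0 and
# approaches the nominal value as the update count x grows, so early
# updates weight the live model heavily and later ones barely move the average.
for x in (1, 100, 1000, 2000, 10000):
    print(f"update {x:>5}: effective decay = {decay * (1 - math.exp(-x / tau)):.4f}")
```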
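
And a minimal usage sketch for the class itself; `ModelEMA.update` is part of the existing class (not shown in this hunk), and the toy `nn.Linear` is a stand-in for a real detection model:

```python
import torch.nn as nn
from ultralytics.utils.torch_utils import ModelEMA

model = nn.Linear(4, 2)  # toy stand-in for a real model
ema = ModelEMA(model, decay=0.9999, tau=2000)

for _ in range(10):  # stand-in for a training loop
    # ... forward pass, loss, optimizer step would go here ...
    ema.update(model)  # blend current weights and buffers into the moving average

ema.enabled = False  # per the new docstring, disables further EMA updates
```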
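
For the `strip_optimizer` changes, a hedged sketch of inspecting the new metadata after stripping; the checkpoint path is hypothetical, and because the merge is `{**updates, **x}`, keys already present in the checkpoint win over the new metadata:

```python
import torch
from ultralytics.utils.torch_utils import strip_optimizer

ckpt = "path/to/best.pt"  # hypothetical path; point at a real Ultralytics *.pt file
strip_optimizer(ckpt)  # rewrites in place; pass s="stripped.pt" to write a copy instead

x = torch.load(ckpt, map_location="cpu")
# Metadata keys prepended by this patch; existing checkpoint keys take
# precedence on collision ("prefer to the right" in the save comment).
print({k: x.get(k) for k in ("date", "version", "license", "docs")})
```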