Default strip_optimizer() to use_dill=False (#14107)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-29 23:43:01 +02:00 committed by GitHub
parent e7ede6564d
commit ff63a56a42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -7,6 +7,7 @@ import random
import time
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Union
@ -456,14 +457,17 @@ def init_seeds(seed=0, deterministic=False):
class ModelEMA:
"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
"""
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
To disable EMA set the `enabled` attribute to `False`.
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
"""Create EMA."""
"""Initialize EMA for 'model' with given arguments."""
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
self.updates = updates # number of EMA updates
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
@ -506,15 +510,25 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
from pathlib import Path
from ultralytics.utils.torch_utils import strip_optimizer
for f in Path('path/to/weights').rglob('*.pt'):
for f in Path('path/to/model/checkpoints').rglob('*.pt'):
strip_optimizer(f)
```
"""
x = torch.load(f, map_location=torch.device("cpu"))
if not isinstance(x, dict) or "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
try:
x = torch.load(f, map_location=torch.device("cpu"))
assert isinstance(x, dict), "checkpoint is not a Python dictionary"
assert "model" in x, "'model' missing from checkpoint"
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
return
updates = {
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}
# Update model
if x.get("ema"):
x["model"] = x["ema"] # replace model with EMA
@ -535,7 +549,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
# x['model'].args = x['train_args']
# Save
torch.save(x, s or f)
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")