Default strip_optimizer() to use_dill=False (#14107)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Glenn Jocher 2024-06-29 23:43:01 +02:00 committed by GitHub
parent e7ede6564d
commit ff63a56a42

ultralytics/utils/torch_utils.py

@@ -7,6 +7,7 @@ import random
 import time
 from contextlib import contextmanager
 from copy import deepcopy
+from datetime import datetime
 from pathlib import Path
 from typing import Union
@@ -456,14 +457,17 @@ def init_seeds(seed=0, deterministic=False):
 class ModelEMA:
-    """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
-    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    """
+    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
+    average of everything in the model state_dict (parameters and buffers)
+
     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
     To disable EMA set the `enabled` attribute to `False`.
     """

     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        """Create EMA."""
+        """Initialize EMA for 'model' with given arguments."""
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
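
For orientation (not part of the commit), a minimal sketch of the update rule `ModelEMA` implements, using a toy `nn.Linear` in place of a real model; `ema_update` and the module-level state here are illustrative, not the library API:

```python
import math
from copy import deepcopy

import torch.nn as nn

# Toy stand-in for a detection model; the real class wraps de_parallel(model).
model = nn.Linear(4, 2)
decay, tau = 0.9999, 2000

ema = deepcopy(model).eval()  # the EMA copy holds the averaged weights
for p in ema.parameters():
    p.requires_grad_(False)

updates = 0


def ema_update(model: nn.Module) -> None:
    """One EMA step: v_ema = d * v_ema + (1 - d) * v_model, with d ramping from 0 toward `decay`."""
    global updates
    updates += 1
    d = decay * (1 - math.exp(-updates / tau))  # small early on, so the EMA tracks the raw weights quickly at first
    msd = model.state_dict()
    for k, v in ema.state_dict().items():  # state_dict tensors share storage, so in-place ops update the EMA module
        if v.dtype.is_floating_point:
            v *= d
            v += (1 - d) * msd[k].detach()


ema_update(model)  # called once per optimizer step during training
```

The `1 - exp(-x / tau)` ramp keeps the effective decay small for the first couple of thousand updates, so the EMA follows the model closely early in training and only later settles into a long-horizon average.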
@@ -506,15 +510,25 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         from pathlib import Path
         from ultralytics.utils.torch_utils import strip_optimizer

-        for f in Path('path/to/weights').rglob('*.pt'):
+        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
             strip_optimizer(f)
         ```
     """
-    x = torch.load(f, map_location=torch.device("cpu"))
-    if not isinstance(x, dict) or "model" not in x:
-        LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
+    try:
+        x = torch.load(f, map_location=torch.device("cpu"))
+        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
+        assert "model" in x, "'model' missing from checkpoint"
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
         return

+    updates = {
+        "date": datetime.now().isoformat(),
+        "version": __version__,
+        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+        "docs": "https://docs.ultralytics.com",
+    }
+
     # Update model
     if x.get("ema"):
         x["model"] = x["ema"]  # replace model with EMA
@@ -535,7 +549,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     # x['model'].args = x['train_args']

     # Save
-    torch.save(x, s or f)
+    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")