Default strip_optimizer() to use_dill=False (#14107)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e7ede6564d
commit
ff63a56a42
1 changed files with 22 additions and 8 deletions
|
|
@ -7,6 +7,7 @@ import random
|
|||
import time
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
|
@ -456,14 +457,17 @@ def init_seeds(seed=0, deterministic=False):
|
|||
|
||||
|
||||
class ModelEMA:
|
||||
"""Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
||||
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
||||
"""
|
||||
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
|
||||
average of everything in the model state_dict (parameters and buffers)
|
||||
|
||||
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
|
||||
To disable EMA set the `enabled` attribute to `False`.
|
||||
"""
|
||||
|
||||
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
||||
"""Create EMA."""
|
||||
"""Initialize EMA for 'model' with given arguments."""
|
||||
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
||||
self.updates = updates # number of EMA updates
|
||||
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
||||
|
|
@ -506,15 +510,25 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
|||
from pathlib import Path
|
||||
from ultralytics.utils.torch_utils import strip_optimizer
|
||||
|
||||
for f in Path('path/to/weights').rglob('*.pt'):
|
||||
for f in Path('path/to/model/checkpoints').rglob('*.pt'):
|
||||
strip_optimizer(f)
|
||||
```
|
||||
"""
|
||||
x = torch.load(f, map_location=torch.device("cpu"))
|
||||
if not isinstance(x, dict) or "model" not in x:
|
||||
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
||||
try:
|
||||
x = torch.load(f, map_location=torch.device("cpu"))
|
||||
assert isinstance(x, dict), "checkpoint is not a Python dictionary"
|
||||
assert "model" in x, "'model' missing from checkpoint"
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
||||
return
|
||||
|
||||
updates = {
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
||||
"docs": "https://docs.ultralytics.com",
|
||||
}
|
||||
|
||||
# Update model
|
||||
if x.get("ema"):
|
||||
x["model"] = x["ema"] # replace model with EMA
|
||||
|
|
@ -535,7 +549,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
|||
# x['model'].args = x['train_args']
|
||||
|
||||
# Save
|
||||
torch.save(x, s or f)
|
||||
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
|
||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue