Default strip_optimizer() to use_dill=False (#14107)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent e7ede6564d
commit ff63a56a42
1 changed file with 22 additions and 8 deletions
```diff
@@ -7,6 +7,7 @@ import random
 import time
 from contextlib import contextmanager
 from copy import deepcopy
+from datetime import datetime
 from pathlib import Path
 from typing import Union
```
```diff
@@ -456,14 +457,17 @@ def init_seeds(seed=0, deterministic=False):


 class ModelEMA:
-    """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
-    Keeps a moving average of everything in the model state_dict (parameters and buffers)
+    """
+    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving
+    average of everything in the model state_dict (parameters and buffers)

     For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

     To disable EMA set the `enabled` attribute to `False`.
     """

     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
-        """Create EMA."""
+        """Initialize EMA for 'model' with given arguments."""
         self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
```
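As context for the `decay` lambda in the hunk above, here is a minimal sketch (not part of the commit) of how the exponential ramp behaves with the default `decay=0.9999` and `tau=2000`: the effective decay starts near zero, so early updates track the live model closely, and it approaches 0.9999 asymptotically.

```python
import math

# Minimal sketch (not from the commit) of ModelEMA's decay ramp with the
# default decay=0.9999 and tau=2000 shown in the diff above.
decay, tau = 0.9999, 2000
ramp = lambda x: decay * (1 - math.exp(-x / tau))

for n in (1, 100, 2000, 20000):
    print(n, f"{ramp(n):.6f}")
# 1     0.000500  -> early updates: EMA follows the live model almost exactly
# 100   0.048766  -> ramping up
# 2000  0.632058  -> after tau updates, ~63% of the target decay
# 20000 0.999855  -> asymptotically approaches decay=0.9999
```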
```diff
@@ -506,15 +510,25 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         from pathlib import Path
         from ultralytics.utils.torch_utils import strip_optimizer

-        for f in Path('path/to/weights').rglob('*.pt'):
+        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
             strip_optimizer(f)
         ```
     """
-    x = torch.load(f, map_location=torch.device("cpu"))
-    if not isinstance(x, dict) or "model" not in x:
-        LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
+    try:
+        x = torch.load(f, map_location=torch.device("cpu"))
+        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
+        assert "model" in x, "'model' missing from checkpoint"
+    except Exception as e:
+        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
         return

+    updates = {
+        "date": datetime.now().isoformat(),
+        "version": __version__,
+        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
+        "docs": "https://docs.ultralytics.com",
+    }
+
     # Update model
     if x.get("ema"):
         x["model"] = x["ema"]  # replace model with EMA
```
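With the change above, an invalid input is skipped with a warning rather than raising. A small illustrative check (hypothetical file name, not part of the commit):

```python
import torch
from ultralytics.utils.torch_utils import strip_optimizer

# Hypothetical file for illustration: a valid torch payload that is not an
# Ultralytics checkpoint dict. With the try/except path above, strip_optimizer()
# should log "WARNING ⚠️ Skipping ... not a valid Ultralytics model:
# checkpoint is not a Python dictionary" and return instead of raising.
torch.save([1, 2, 3], "not_a_checkpoint.pt")
strip_optimizer("not_a_checkpoint.pt")
```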
```diff
@@ -535,7 +549,7 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     # x['model'].args = x['train_args']

     # Save
-    torch.save(x, s or f)
+    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
```
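Two details of the new save call are easy to miss. `{**updates, **x}` merges the metadata into the checkpoint with right-hand precedence, so keys already present in the loaded checkpoint win over the freshly built `updates`. And `use_dill` is not a keyword of stock `torch.save`; it appears to be handled by Ultralytics' patched save wrapper, with `use_dill=False` forcing the standard `pickle` module so stripped checkpoints stay loadable without dill installed. A minimal sketch of the merge semantics (not part of the commit):

```python
# Minimal sketch of the {**updates, **x} merge used in the save call above:
# keys from x (the loaded checkpoint) take precedence, so existing metadata
# is preserved and only missing keys are filled in from updates.
updates = {"date": "2024-07-01T00:00:00", "version": "8.2.0"}
x = {"model": None, "version": "8.1.0"}  # checkpoint already carries a version
merged = {**updates, **x}
assert merged["version"] == "8.1.0"  # right-hand dict wins on conflicts
assert merged["date"] == "2024-07-01T00:00:00"  # missing key added from updates
```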