ultralytics 8.2.48 strip model criterion on save (#14106)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-29 21:50:58 +02:00 committed by GitHub
parent 7c1999929a
commit e7ede6564d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 18 additions and 11 deletions

View file

@ -511,23 +511,30 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
```
"""
x = torch.load(f, map_location=torch.device("cpu"))
if "model" not in x:
if not isinstance(x, dict) or "model" not in x:
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
return
# Update model
if x.get("ema"):
x["model"] = x["ema"] # replace model with EMA
if hasattr(x["model"], "args"):
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
if x.get("ema"):
x["model"] = x["ema"] # replace model with ema
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x["epoch"] = -1
if hasattr(x["model"], "criterion"):
x["model"].criterion = None # strip loss criterion
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
# Update other keys
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
for k in "optimizer", "best_fitness", "ema", "updates": # keys
x[k] = None
x["epoch"] = -1
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
# x['model'].args = x['train_args']
# Save
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")