ultralytics 8.2.48 strip model criterion on save (#14106)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
7c1999929a
commit
e7ede6564d
3 changed files with 18 additions and 11 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.47"
|
||||
__version__ = "8.2.48"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -788,14 +788,14 @@ def torch_safe_load(weight):
|
|||
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
||||
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
)
|
||||
) from e
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
||||
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
|
||||
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
)
|
||||
check_requirements(e.name) # install missing module
|
||||
ckpt = torch.load(file, map_location="cpu")
|
||||
|
|
|
|||
|
|
@ -511,23 +511,30 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
|||
```
|
||||
"""
|
||||
x = torch.load(f, map_location=torch.device("cpu"))
|
||||
if "model" not in x:
|
||||
if not isinstance(x, dict) or "model" not in x:
|
||||
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
||||
return
|
||||
|
||||
# Update model
|
||||
if x.get("ema"):
|
||||
x["model"] = x["ema"] # replace model with EMA
|
||||
if hasattr(x["model"], "args"):
|
||||
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
||||
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
|
||||
if x.get("ema"):
|
||||
x["model"] = x["ema"] # replace model with ema
|
||||
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
||||
x[k] = None
|
||||
x["epoch"] = -1
|
||||
if hasattr(x["model"], "criterion"):
|
||||
x["model"].criterion = None # strip loss criterion
|
||||
x["model"].half() # to FP16
|
||||
for p in x["model"].parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# Update other keys
|
||||
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
|
||||
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
||||
x[k] = None
|
||||
x["epoch"] = -1
|
||||
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||
# x['model'].args = x['train_args']
|
||||
|
||||
# Save
|
||||
torch.save(x, s or f)
|
||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue