ultralytics 8.2.48 strip model criterion on save (#14106)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
7c1999929a
commit
e7ede6564d
3 changed files with 18 additions and 11 deletions
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.2.47"
|
__version__ = "8.2.48"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -788,14 +788,14 @@ def torch_safe_load(weight):
|
||||||
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
||||||
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
||||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||||
)
|
)
|
||||||
) from e
|
) from e
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements."
|
||||||
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
||||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||||
)
|
)
|
||||||
check_requirements(e.name) # install missing module
|
check_requirements(e.name) # install missing module
|
||||||
ckpt = torch.load(file, map_location="cpu")
|
ckpt = torch.load(file, map_location="cpu")
|
||||||
|
|
|
||||||
|
|
@ -511,23 +511,30 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
x = torch.load(f, map_location=torch.device("cpu"))
|
x = torch.load(f, map_location=torch.device("cpu"))
|
||||||
if "model" not in x:
|
if not isinstance(x, dict) or "model" not in x:
|
||||||
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
LOGGER.info(f"Skipping {f}, not a valid Ultralytics model.")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Update model
|
||||||
|
if x.get("ema"):
|
||||||
|
x["model"] = x["ema"] # replace model with EMA
|
||||||
if hasattr(x["model"], "args"):
|
if hasattr(x["model"], "args"):
|
||||||
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
|
||||||
args = {**DEFAULT_CFG_DICT, **x["train_args"]} if "train_args" in x else None # combine args
|
if hasattr(x["model"], "criterion"):
|
||||||
if x.get("ema"):
|
x["model"].criterion = None # strip loss criterion
|
||||||
x["model"] = x["ema"] # replace model with ema
|
|
||||||
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
|
||||||
x[k] = None
|
|
||||||
x["epoch"] = -1
|
|
||||||
x["model"].half() # to FP16
|
x["model"].half() # to FP16
|
||||||
for p in x["model"].parameters():
|
for p in x["model"].parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
|
# Update other keys
|
||||||
|
args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
|
||||||
|
for k in "optimizer", "best_fitness", "ema", "updates": # keys
|
||||||
|
x[k] = None
|
||||||
|
x["epoch"] = -1
|
||||||
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
|
||||||
# x['model'].args = x['train_args']
|
# x['model'].args = x['train_args']
|
||||||
|
|
||||||
|
# Save
|
||||||
torch.save(x, s or f)
|
torch.save(x, s or f)
|
||||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue