ultralytics 8.2.95 faster checkpoint saving (#16311)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-09-16 21:37:14 +02:00 committed by GitHub
parent 7b19e0daa0
commit ba438aea5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 53 additions and 58 deletions

View file

@ -533,16 +533,17 @@ class ModelEMA:
copy_attr(self.ema, model, include, exclude)
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
"""
Strip optimizer from 'f' to finalize training, optionally save as 's'.
Args:
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
Returns:
None
(dict): The combined checkpoint dictionary.
Example:
```python
@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
assert "model" in x, "'model' missing from checkpoint"
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
return
return {}
updates = {
metadata = {
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
# x['model'].args = x['train_args']
# Save
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
combined = {**metadata, **x, **(updates or {})}
torch.save(combined, s or f, use_dill=False) # combine dicts (prefer to the right)
mb = os.path.getsize(s or f) / 1e6 # file size
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
return combined
def convert_optimizer_state_dict_to_fp16(state_dict):