ultralytics 8.2.95 faster checkpoint saving (#16311)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
7b19e0daa0
commit
ba438aea5a
5 changed files with 53 additions and 58 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.94"
|
||||
__version__ = "8.2.95"
|
||||
|
||||
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -668,13 +668,14 @@ class BaseTrainer:
|
|||
|
||||
def final_eval(self):
|
||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||
ckpt = {}
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
if self.last.is_file(): # update best.pt train_metrics from last.pt
|
||||
k = "train_results"
|
||||
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
|
||||
if f is self.last:
|
||||
ckpt = strip_optimizer(f)
|
||||
elif f is self.best:
|
||||
k = "train_results" # update best.pt train_metrics from last.pt
|
||||
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
|
||||
LOGGER.info(f"\nValidating {f}...")
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.metrics = self.validator(model=f)
|
||||
|
|
|
|||
|
|
@ -759,6 +759,10 @@ class SafeClass:
|
|||
"""Initialize SafeClass instance, ignoring all arguments."""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Run SafeClass instance, ignoring all arguments."""
|
||||
pass
|
||||
|
||||
|
||||
class SafeUnpickler(pickle.Unpickler):
|
||||
"""Custom Unpickler that replaces unknown classes with SafeClass."""
|
||||
|
|
|
|||
|
|
@ -533,16 +533,17 @@ class ModelEMA:
|
|||
copy_attr(self.ema, model, include, exclude)
|
||||
|
||||
|
||||
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
||||
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
|
||||
"""
|
||||
Strip optimizer from 'f' to finalize training, optionally save as 's'.
|
||||
|
||||
Args:
|
||||
f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
|
||||
s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
|
||||
updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.
|
||||
|
||||
Returns:
|
||||
None
|
||||
(dict): The combined checkpoint dictionary.
|
||||
|
||||
Example:
|
||||
```python
|
||||
|
|
@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
|||
assert "model" in x, "'model' missing from checkpoint"
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
|
||||
return
|
||||
return {}
|
||||
|
||||
updates = {
|
||||
metadata = {
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
||||
|
|
@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
|
|||
# x['model'].args = x['train_args']
|
||||
|
||||
# Save
|
||||
torch.save({**updates, **x}, s or f, use_dill=False) # combine dicts (prefer to the right)
|
||||
combined = {**metadata, **x, **(updates or {})}
|
||||
torch.save(combined, s or f, use_dill=False) # combine dicts (prefer to the right)
|
||||
mb = os.path.getsize(s or f) / 1e6 # file size
|
||||
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
||||
return combined
|
||||
|
||||
|
||||
def convert_optimizer_state_dict_to_fp16(state_dict):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue