From 1703025e8eba9689132ae042d51551942b6d5856 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 31 Mar 2024 15:11:42 +0200
Subject: [PATCH] Save optimizer as FP16 for smaller checkpoints (#9435)

Signed-off-by: Glenn Jocher
Co-authored-by: UltralyticsAssistant
---
 docs/en/reference/utils/torch_utils.md |  4 ++++
 ultralytics/engine/trainer.py          |  3 ++-
 ultralytics/utils/torch_utils.py       | 14 ++++++++++++++
 3 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/docs/en/reference/utils/torch_utils.md b/docs/en/reference/utils/torch_utils.md
index 5c88f293..971a1072 100644
--- a/docs/en/reference/utils/torch_utils.md
+++ b/docs/en/reference/utils/torch_utils.md
@@ -115,6 +115,10 @@ keywords: Ultralytics, Torch Utils, Model EMA, Early Stopping, Smart Inference,
 
 <br><br>
 
+## ::: ultralytics.utils.torch_utils.convert_optimizer_state_dict_to_fp16
+
+<br><br>
+
 ## ::: ultralytics.utils.torch_utils.profile
 
 <br><br>
diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index 7a963a38..8b5e47cc 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -42,6 +42,7 @@ from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
     EarlyStopping,
     ModelEMA,
+    convert_optimizer_state_dict_to_fp16,
     init_seeds,
     one_cycle,
     select_device,
@@ -488,7 +489,7 @@ class BaseTrainer:
             "model": None,  # resume and final checkpoints derive from EMA
             "ema": deepcopy(self.ema.ema).half(),
             "updates": self.ema.updates,
-            "optimizer": self.optimizer.state_dict(),
+            "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
             "train_args": vars(self.args),  # save as dict
             "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
             "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 77d8cc8c..77449b04 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -505,6 +505,20 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
 
 
+def convert_optimizer_state_dict_to_fp16(state_dict):
+    """
+    Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
+
+    This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
+    """
+    for state in state_dict["state"].values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
+                state[k] = v.half()
+
+    return state_dict
+
+
 def profile(input, ops, n=10, device=None):
     """
     Ultralytics speed, memory and FLOPs profiler.
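
Editor's note: to see the size effect of this patch in isolation, here is a minimal standalone sketch. The helper body is copied verbatim from the torch_utils.py hunk above; the toy nn.Linear model, the Adam warm-up steps, and the temporary-file comparison are illustrative assumptions, not part of the commit.

import os
import tempfile
from copy import deepcopy

import torch


def convert_optimizer_state_dict_to_fp16(state_dict):
    """Same logic as the patched helper: cast FP32 tensors under 'state' to FP16."""
    for state in state_dict["state"].values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
                state[k] = v.half()
    return state_dict


# Toy model plus a few Adam steps, so the optimizer accumulates FP32
# exp_avg / exp_avg_sq tensors (roughly 2x the parameter count) in its state.
model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
for _ in range(3):
    model(torch.randn(8, 1024)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()

with tempfile.TemporaryDirectory() as tmp:
    fp32_path = os.path.join(tmp, "fp32.pt")
    fp16_path = os.path.join(tmp, "fp16.pt")
    torch.save(optimizer.state_dict(), fp32_path)
    # deepcopy first, exactly as the trainer.py hunk does, so the live
    # optimizer keeps stepping in full precision.
    torch.save(convert_optimizer_state_dict_to_fp16(deepcopy(optimizer.state_dict())), fp16_path)
    print(f"FP32 optimizer state: {os.path.getsize(fp32_path) / 1e6:.2f} MB")
    print(f"FP16 optimizer state: {os.path.getsize(fp16_path) / 1e6:.2f} MB")

The helper walks only 'state' because 'param_groups' holds plain Python scalars such as lr and betas, as the docstring notes. One caveat worth knowing: depending on the PyTorch version, Adam may store its step counter as a 0-dim FP32 tensor, in which case it is halved as well (FP16 represents integers exactly only up to 2048).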