From 479afce4a855ab29ea39ec083cfb7edd472d07cf Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 31 Mar 2024 05:53:34 +0200 Subject: [PATCH] Immediate checkpoint serialization (#9437) --- ultralytics/engine/trainer.py | 47 +++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 8a993c25..2dcbf9fe 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -477,32 +477,37 @@ class BaseTrainer: def save_model(self): """Save model training checkpoints with additional metadata.""" + import io import pandas as pd # scope for faster startup - metrics = {**self.metrics, **{"fitness": self.fitness}} - results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} - ckpt = { - "epoch": self.epoch, - "best_fitness": self.best_fitness, - "model": deepcopy(de_parallel(self.model)).half(), - "ema": deepcopy(self.ema.ema).half(), - "updates": self.ema.updates, - "optimizer": self.optimizer.state_dict(), - "train_args": vars(self.args), # save as dict - "train_metrics": metrics, - "train_results": results, - "date": datetime.now().isoformat(), - "version": __version__, - "license": "AGPL-3.0 (https://ultralytics.com/license)", - "docs": "https://docs.ultralytics.com", - } + # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) + buffer = io.BytesIO() + torch.save( + { + "epoch": self.epoch, + "best_fitness": self.best_fitness, + "model": deepcopy(de_parallel(self.model)).half(), + "ema": deepcopy(self.ema.ema).half(), + "updates": self.ema.updates, + "optimizer": self.optimizer.state_dict(), + "train_args": vars(self.args), # save as dict + "train_metrics": {**self.metrics, **{"fitness": self.fitness}}, + "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}, + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + }, + buffer, + ) + serialized_ckpt = buffer.getvalue() # get the serialized content to save - # Save last and best - torch.save(ckpt, self.last) + # Save checkpoints + self.last.write_bytes(serialized_ckpt) # save last.pt if self.best_fitness == self.fitness: - self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt + self.best.write_bytes(serialized_ckpt) # save best.pt if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): - (self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt + (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' @staticmethod def get_dataset(data):