Immediate checkpoint serialization (#9437)

This commit is contained in:
Glenn Jocher 2024-03-31 05:53:34 +02:00 committed by GitHub
parent eed01a2cf3
commit 479afce4a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -477,11 +477,13 @@ class BaseTrainer:
def save_model(self): def save_model(self):
"""Save model training checkpoints with additional metadata.""" """Save model training checkpoints with additional metadata."""
import io
import pandas as pd # scope for faster startup import pandas as pd # scope for faster startup
metrics = {**self.metrics, **{"fitness": self.fitness}} # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} buffer = io.BytesIO()
ckpt = { torch.save(
{
"epoch": self.epoch, "epoch": self.epoch,
"best_fitness": self.best_fitness, "best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(), "model": deepcopy(de_parallel(self.model)).half(),
@@ -489,20 +491,23 @@ class BaseTrainer:
"updates": self.ema.updates, "updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(), "optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict "train_args": vars(self.args), # save as dict
"train_metrics": metrics, "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": results, "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"version": __version__, "version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)", "license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com", "docs": "https://docs.ultralytics.com",
} },
buffer,
)
serialized_ckpt = buffer.getvalue() # get the serialized content to save
# Save last and best # Save checkpoints
torch.save(ckpt, self.last) self.last.write_bytes(serialized_ckpt) # save last.pt
if self.best_fitness == self.fitness: if self.best_fitness == self.fitness:
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt self.best.write_bytes(serialized_ckpt) # save best.pt
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0): if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
@staticmethod @staticmethod
def get_dataset(data): def get_dataset(data):