Immediate checkpoint serialization (#9437)

This commit is contained in:
Glenn Jocher 2024-03-31 05:53:34 +02:00 committed by GitHub
parent eed01a2cf3
commit 479afce4a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -477,32 +477,37 @@ class BaseTrainer:
def save_model(self):
"""Save model training checkpoints with additional metadata."""
import io
import pandas as pd # scope for faster startup
metrics = {**self.metrics, **{"fitness": self.fitness}}
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
ckpt = {
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(),
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": metrics,
"train_results": results,
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
buffer = io.BytesIO()
torch.save(
{
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(),
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
},
buffer,
)
serialized_ckpt = buffer.getvalue() # get the serialized content to save
# Save last and best
torch.save(ckpt, self.last)
# Save checkpoints
self.last.write_bytes(serialized_ckpt) # save last.pt
if self.best_fitness == self.fitness:
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt
self.best.write_bytes(serialized_ckpt) # save best.pt
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
@staticmethod
def get_dataset(data):