Immediate checkpoint serialization (#9437)
This commit is contained in:
parent
eed01a2cf3
commit
479afce4a8
1 changed files with 26 additions and 21 deletions
|
|
@ -477,32 +477,37 @@ class BaseTrainer:
|
|||
|
||||
def save_model(self):
|
||||
"""Save model training checkpoints with additional metadata."""
|
||||
import io
|
||||
import pandas as pd # scope for faster startup
|
||||
|
||||
metrics = {**self.metrics, **{"fitness": self.fitness}}
|
||||
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
|
||||
ckpt = {
|
||||
"epoch": self.epoch,
|
||||
"best_fitness": self.best_fitness,
|
||||
"model": deepcopy(de_parallel(self.model)).half(),
|
||||
"ema": deepcopy(self.ema.ema).half(),
|
||||
"updates": self.ema.updates,
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"train_args": vars(self.args), # save as dict
|
||||
"train_metrics": metrics,
|
||||
"train_results": results,
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||
"docs": "https://docs.ultralytics.com",
|
||||
}
|
||||
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
||||
buffer = io.BytesIO()
|
||||
torch.save(
|
||||
{
|
||||
"epoch": self.epoch,
|
||||
"best_fitness": self.best_fitness,
|
||||
"model": deepcopy(de_parallel(self.model)).half(),
|
||||
"ema": deepcopy(self.ema.ema).half(),
|
||||
"updates": self.ema.updates,
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
"train_args": vars(self.args), # save as dict
|
||||
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
||||
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||
"docs": "https://docs.ultralytics.com",
|
||||
},
|
||||
buffer,
|
||||
)
|
||||
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
||||
|
||||
# Save last and best
|
||||
torch.save(ckpt, self.last)
|
||||
# Save checkpoints
|
||||
self.last.write_bytes(serialized_ckpt) # save last.pt
|
||||
if self.best_fitness == self.fitness:
|
||||
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt
|
||||
self.best.write_bytes(serialized_ckpt) # save best.pt
|
||||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt
|
||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(data):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue