Immediate checkpoint serialization (#9437)
This commit is contained in:
parent
eed01a2cf3
commit
479afce4a8
1 changed files with 26 additions and 21 deletions
|
|
@ -477,32 +477,37 @@ class BaseTrainer:
|
||||||
|
|
||||||
def save_model(self):
|
def save_model(self):
|
||||||
"""Save model training checkpoints with additional metadata."""
|
"""Save model training checkpoints with additional metadata."""
|
||||||
|
import io
|
||||||
import pandas as pd # scope for faster startup
|
import pandas as pd # scope for faster startup
|
||||||
|
|
||||||
metrics = {**self.metrics, **{"fitness": self.fitness}}
|
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
||||||
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
|
buffer = io.BytesIO()
|
||||||
ckpt = {
|
torch.save(
|
||||||
"epoch": self.epoch,
|
{
|
||||||
"best_fitness": self.best_fitness,
|
"epoch": self.epoch,
|
||||||
"model": deepcopy(de_parallel(self.model)).half(),
|
"best_fitness": self.best_fitness,
|
||||||
"ema": deepcopy(self.ema.ema).half(),
|
"model": deepcopy(de_parallel(self.model)).half(),
|
||||||
"updates": self.ema.updates,
|
"ema": deepcopy(self.ema.ema).half(),
|
||||||
"optimizer": self.optimizer.state_dict(),
|
"updates": self.ema.updates,
|
||||||
"train_args": vars(self.args), # save as dict
|
"optimizer": self.optimizer.state_dict(),
|
||||||
"train_metrics": metrics,
|
"train_args": vars(self.args), # save as dict
|
||||||
"train_results": results,
|
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
||||||
"date": datetime.now().isoformat(),
|
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
|
||||||
"version": __version__,
|
"date": datetime.now().isoformat(),
|
||||||
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
"version": __version__,
|
||||||
"docs": "https://docs.ultralytics.com",
|
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||||
}
|
"docs": "https://docs.ultralytics.com",
|
||||||
|
},
|
||||||
|
buffer,
|
||||||
|
)
|
||||||
|
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
||||||
|
|
||||||
# Save last and best
|
# Save checkpoints
|
||||||
torch.save(ckpt, self.last)
|
self.last.write_bytes(serialized_ckpt) # save last.pt
|
||||||
if self.best_fitness == self.fitness:
|
if self.best_fitness == self.fitness:
|
||||||
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt
|
self.best.write_bytes(serialized_ckpt) # save best.pt
|
||||||
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
|
||||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt
|
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_dataset(data):
|
def get_dataset(data):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue