Faster best.pt checkpoint saving (#9436)

This commit is contained in:
Glenn Jocher 2024-03-31 05:01:30 +02:00 committed by GitHub
parent ea80b14d72
commit eed01a2cf3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -500,9 +500,9 @@ class BaseTrainer:
# Save last and best
torch.save(ckpt, self.last)
if self.best_fitness == self.fitness:
torch.save(ckpt, self.best)
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
torch.save(ckpt, self.wdir / f"epoch{self.epoch}.pt")
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt
@staticmethod
def get_dataset(data):