'best.pt' inherit all-epochs results curves from 'last.pt' (#15791)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-24 23:17:09 +08:00 committed by GitHub
parent 18bc4e85c4
commit c1882a4327
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 14 additions and 9 deletions

View file

@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import (
class BaseTrainer:
"""
BaseTrainer.
A base class for creating trainers.
Attributes:
@ -478,12 +476,16 @@ class BaseTrainer:
torch.cuda.empty_cache()
self.run_callbacks("teardown")
def read_results_csv(self):
"""Read results.csv into a dict using pandas."""
import pandas as pd # scope for faster 'import ultralytics'
return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
def save_model(self):
"""Save model training checkpoints with additional metadata."""
import io
import pandas as pd # scope for faster 'import ultralytics'
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
buffer = io.BytesIO()
torch.save(
@ -496,7 +498,7 @@ class BaseTrainer:
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
"train_results": self.read_results_csv(),
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
@ -646,6 +648,9 @@ class BaseTrainer:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
if self.last.is_file(): # update best.pt train_metrics from last.pt
k = "train_results"
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
LOGGER.info(f"\nValidating {f}...")
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)