ImageNet names, classify inference, resume fixes (#712)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
aecd17d455
commit
522f1937ed
16 changed files with 1121 additions and 115 deletions
|
|
@ -21,7 +21,7 @@ from torch.optim import lr_scheduler
|
|||
from tqdm import tqdm
|
||||
|
||||
from ultralytics import __version__
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis,
|
||||
|
|
@ -515,14 +515,15 @@ class BaseTrainer:
|
|||
def check_resume(self):
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
last = Path(check_file(resume) if isinstance(resume, (str, Path)) else get_latest_run())
|
||||
args_yaml = last.parent.parent / 'args.yaml' # train options yaml
|
||||
assert args_yaml.is_file(), \
|
||||
FileNotFoundError(f'Resume checkpoint {last} not found. '
|
||||
'Please pass a valid checkpoint to resume from, i.e. yolo resume=path/to/last.pt')
|
||||
args = get_cfg(args_yaml) # replace
|
||||
args.model, resume = str(last), True # reinstate
|
||||
self.args = args
|
||||
try:
|
||||
last = Path(
|
||||
check_file(resume) if isinstance(resume, (str,
|
||||
Path)) and Path(resume).exists() else get_latest_run())
|
||||
self.args = get_cfg(attempt_load_weights(last).args)
|
||||
self.args.model, resume = str(last), True # reinstate
|
||||
except Exception as e:
|
||||
raise FileNotFoundError("Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'") from e
|
||||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
|
|
@ -541,7 +542,7 @@ class BaseTrainer:
|
|||
f'{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n' \
|
||||
f"Start a new training without --resume, i.e. 'yolo task=... mode=train model={self.args.model}'"
|
||||
LOGGER.info(
|
||||
f'Resuming training from {self.args.model} from epoch {start_epoch} to {self.epochs} total epochs')
|
||||
f'Resuming training from {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs')
|
||||
if self.epochs < start_epoch:
|
||||
LOGGER.info(
|
||||
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue