Add new get_save_dir() function (#4602)

This commit is contained in:
Glenn Jocher 2023-08-28 10:43:41 +02:00 committed by GitHub
parent 1121ef2409
commit 23b4f697c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 39 deletions

View file

@ -23,15 +23,15 @@ from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm
from ultralytics.cfg import get_cfg
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, clean_url,
colorstr, emojis, yaml_save)
from ultralytics.utils import (DEFAULT_CFG, LOGGER, RANK, TQDM_BAR_FORMAT, __version__, callbacks, clean_url, colorstr,
emojis, yaml_save)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run, increment_path
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, init_seeds, one_cycle, select_device,
strip_optimizer)
@ -91,13 +91,7 @@ class BaseTrainer:
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
# Dirs
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}'
if hasattr(self.args, 'save_dir'):
self.save_dir = Path(self.args.save_dir)
else:
self.save_dir = Path(
increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))
self.save_dir = get_save_dir(self.args)
self.wdir = self.save_dir / 'weights' # weights dir
if RANK in (-1, 0):
self.wdir.mkdir(parents=True, exist_ok=True) # make dir