Add new get_save_dir() function (#4602)

This commit is contained in:
Glenn Jocher 2023-08-28 10:43:41 +02:00 committed by GitHub
parent 1121ef2409
commit 23b4f697c9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 39 deletions

View file

@ -8,9 +8,9 @@ from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Union
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
yaml_print)
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
yaml_load, yaml_print)
# Define valid tasks and modes
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
return IterableSimpleNamespace(**cfg)
def get_save_dir(args):
"""Return save_dir as created from train/val/predict arguments."""
if getattr(args, 'save_dir', None):
save_dir = args.save_dir
else:
from ultralytics.utils.files import increment_path
project = args.project or Path(SETTINGS['runs_dir']) / args.task
name = args.name or f'{args.mode}'
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
return Path(save_dir)
def _handle_deprecation(custom):
"""Hardcoded function to handle deprecated config keys"""
"""Hardcoded function to handle deprecated config keys."""
for key in custom.copy().keys():
if key == 'hide_labels':
@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
Args:
custom (dict): a dictionary of custom configuration options
base (dict): a dictionary of base configuration options
e (Error, optional): An optional error that is passed by the calling function.
"""
custom = _handle_deprecation(custom)
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))