Add new get_save_dir() function (#4602)
This commit is contained in:
parent
1121ef2409
commit
23b4f697c9
6 changed files with 35 additions and 39 deletions
|
|
@ -8,9 +8,9 @@ from pathlib import Path
|
|||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, SETTINGS, SETTINGS_YAML,
|
||||
IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn, yaml_load,
|
||||
yaml_print)
|
||||
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_PATH, LOGGER, RANK, SETTINGS,
|
||||
SETTINGS_YAML, IterableSimpleNamespace, __version__, checks, colorstr, deprecation_warn,
|
||||
yaml_load, yaml_print)
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark'
|
||||
|
|
@ -146,8 +146,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def get_save_dir(args):
|
||||
"""Return save_dir as created from train/val/predict arguments."""
|
||||
|
||||
if getattr(args, 'save_dir', None):
|
||||
save_dir = args.save_dir
|
||||
else:
|
||||
from ultralytics.utils.files import increment_path
|
||||
|
||||
project = args.project or Path(SETTINGS['runs_dir']) / args.task
|
||||
name = args.name or f'{args.mode}'
|
||||
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
|
||||
|
||||
return Path(save_dir)
|
||||
|
||||
|
||||
def _handle_deprecation(custom):
|
||||
"""Hardcoded function to handle deprecated config keys"""
|
||||
"""Hardcoded function to handle deprecated config keys."""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
if key == 'hide_labels':
|
||||
|
|
@ -171,6 +186,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
|||
Args:
|
||||
custom (dict): a dictionary of custom configuration options
|
||||
base (dict): a dictionary of base configuration options
|
||||
e (Error, optional): An optional error that is passed by the calling function.
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue