diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py index 68a69632..0df109c7 100644 --- a/ultralytics/engine/tuner.py +++ b/ultralytics/engine/tuner.py @@ -101,7 +101,8 @@ class Tuner: "copy_paste": (0.0, 1.0), # segment copy-paste (probability) } self.args = get_cfg(overrides=args) - self.tune_dir = get_save_dir(self.args, name="tune") + self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune") + self.args.name = None # reset to not affect training directory self.tune_csv = self.tune_dir / "tune_results.csv" self.callbacks = _callbacks or callbacks.get_default_callbacks() self.prefix = colorstr("Tuner: ") diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py index 0b0ac765..10a77d56 100644 --- a/ultralytics/utils/tuner.py +++ b/ultralytics/utils/tuner.py @@ -1,6 +1,6 @@ # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir +from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks @@ -134,7 +134,9 @@ def run_ray_tune( tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else [] # Create the Ray Tune hyperparameter search tuner - tune_dir = get_save_dir(DEFAULT_CFG, name="tune").resolve() # must be absolute dir + tune_dir = get_save_dir( + get_cfg(DEFAULT_CFG, train_args), name=train_args.pop("name", "tune") + ).resolve() # must be absolute dir tune_dir.mkdir(parents=True, exist_ok=True) tuner = tune.Tuner( trainable_with_resources,