ultralytics 8.0.197 save P, R, F1 curves to metrics (#5354)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: erminkev1 <83356055+erminkev1@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Andy <39454881+yermandy@users.noreply.github.com>
Glenn Jocher 2023-10-13 02:49:31 +02:00 committed by GitHub
parent 7fd5dcbd86
commit 12e3eef844
33 changed files with 337 additions and 195 deletions


@@ -45,6 +45,7 @@ def run_ray_tune(model,
     try:
         subprocess.run('pip install ray[tune]'.split(), check=True)
+        import ray
         from ray import tune
         from ray.air import RunConfig
         from ray.air.integrations.wandb import WandbLoggerCallback
@@ -83,6 +84,10 @@ def run_ray_tune(model,
         'mixup': tune.uniform(0.0, 1.0),  # image mixup (probability)
         'copy_paste': tune.uniform(0.0, 1.0)}  # segment copy-paste (probability)
+    # Put the model in ray store
+    task = model.task
+    model_in_store = ray.put(model)
     def _tune(config):
         """
         Trains the YOLO model with the specified hyperparameters and additional arguments.
@@ -93,9 +98,10 @@ def run_ray_tune(model,
         Returns:
             None.
         """
-        model.reset_callbacks()
+        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
+        model_to_train.reset_callbacks()
         config.update(train_args)
-        results = model.train(**config)
+        results = model_to_train.train(**config)
         return results.results_dict
     # Get search space
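
The two hunks above are the core of the change: the model is placed in the Ray object store once with ray.put() and each trial fetches its own handle with ray.get(), instead of the trainable closing over the outer model directly. A minimal, self-contained sketch of that put/get pattern, assuming Ray 2.x's tune.Tuner API (the stored object and names are illustrative stand-ins, not the ultralytics model):

# Sketch of the ray.put()/ray.get() pattern shown in the diff above (illustrative only):
# a large object goes into the Ray object store once and each Tune trial fetches it
# by reference instead of re-serializing the closure that captured it.
import ray
from ray import tune

ray.init(ignore_reinit_error=True)

big_object = {'weights': list(range(1000))}   # stand-in for a YOLO model
obj_ref = ray.put(big_object)                 # store once, share by reference

def trainable(config):
    obj = ray.get(obj_ref)                    # fetch the stored object inside the trial
    return {'score': config['lr'] * len(obj['weights'])}  # final metrics for this trial

tuner = tune.Tuner(trainable, param_space={'lr': tune.uniform(1e-5, 1e-1)})
tuner.fit()
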
@@ -104,7 +110,7 @@ def run_ray_tune(model,
         LOGGER.warning('WARNING ⚠️ search space not provided, using default search space.')
     # Get dataset
-    data = train_args.get('data', TASK2DATA[model.task])
+    data = train_args.get('data', TASK2DATA[task])
     space['data'] = data
     if 'data' not in train_args:
         LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".')
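
This one-line change simply reuses the task string captured before the model was moved into the object store; the defaulting itself is an ordinary dict lookup behind dict.get(). A tiny illustrative sketch (the mapping values are examples, not the full TASK2DATA table):

# Illustrative sketch of the task-to-default-dataset fallback (example values only).
TASK2DATA = {'detect': 'coco8.yaml', 'segment': 'coco8-seg.yaml', 'classify': 'imagenet10'}

task = 'detect'        # captured from model.task before ray.put(model)
train_args = {}        # user overrides; empty here, so the default applies
data = train_args.get('data', TASK2DATA[task])
print(data)            # -> coco8.yaml
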
@@ -114,7 +120,7 @@ def run_ray_tune(model,
     # Define the ASHA scheduler for hyperparameter search
     asha_scheduler = ASHAScheduler(time_attr='epoch',
-                                   metric=TASK2METRIC[model.task],
+                                   metric=TASK2METRIC[task],
                                    mode='max',
                                    max_t=train_args.get('epochs') or DEFAULT_CFG_DICT['epochs'] or 100,
                                    grace_period=grace_period,
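
Below the shown lines this scheduler is handed to the Tune run. A short sketch of how an ASHAScheduler like this is typically constructed and wired into a tune.Tuner, assuming Ray 2.x's AIR-era API (the metric name, numbers, and dummy trainable are illustrative, not the ultralytics defaults):

# Sketch of an ASHA-scheduled Tune run (Ray 2.x AIR API assumed; values illustrative).
from ray import tune
from ray.air import session
from ray.tune.schedulers import ASHAScheduler

asha = ASHAScheduler(time_attr='epoch',       # progress is measured by the reported 'epoch'
                     metric='score',          # maximize this reported key
                     mode='max',
                     max_t=100,               # hard cap of 100 epochs per trial
                     grace_period=10,         # every trial runs at least 10 epochs
                     reduction_factor=3)      # keep roughly the top 1/3 at each rung

def trainable(config):
    for epoch in range(100):
        # report intermediate results each epoch so ASHA can stop weak trials early
        session.report({'epoch': epoch, 'score': config['lr'] * epoch})

tuner = tune.Tuner(trainable,
                   param_space={'lr': tune.uniform(1e-5, 1e-1)},
                   tune_config=tune.TuneConfig(scheduler=asha, num_samples=4))
tuner.fit()
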