Add Hyperparameter evolution Tuner() class (#4599)

This commit is contained in:
Glenn Jocher 2023-08-29 02:42:01 +02:00 committed by GitHub
parent 7e99804263
commit 4bd62a299c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 403 additions and 91 deletions

View file

@ -344,6 +344,25 @@ class Model:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
return self.metrics
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
"""
Runs hyperparameter tuning, optionally using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
"""
self._check_is_pytorch_model()
if use_ray:
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, max_samples=iterations, *args, **kwargs)
else:
from .tuner import Tuner
custom = {} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'export'} # highest priority args on the right
return Tuner(args=args, _callbacks=self.callbacks)(model=self.model, iterations=iterations)
def to(self, device):
"""
@ -356,20 +375,6 @@ class Model:
self.model.to(device)
return self
def tune(self, *args, **kwargs):
"""
Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
Returns:
(dict): A dictionary containing the results of the hyperparameter search.
Raises:
ModuleNotFoundError: If Ray Tune is not installed.
"""
self._check_is_pytorch_model()
from ultralytics.utils.tuner import run_ray_tune
return run_ray_tune(self, *args, **kwargs)
@property
def names(self):
"""Returns class names of the loaded model."""