Threadpool fixes and CLI improvements (#550)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2023-01-22 17:08:08 +01:00 committed by GitHub
parent d9a0fba251
commit 21b701c4ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 338 additions and 251 deletions

View file

@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
@ -151,7 +151,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
@ -169,7 +169,7 @@ class YOLO:
overrides = self.overrides.copy()
overrides.update(kwargs)
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
print(args)
@ -181,8 +181,7 @@ class YOLO:
Trains the model on a given dataset.
Args:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
**kwargs (Any): Any number of arguments representing the training configuration.
"""
overrides = self.overrides.copy()
overrides.update(kwargs)
@ -192,7 +191,7 @@ class YOLO:
overrides["task"] = self.task
overrides["mode"] = "train"
if not overrides.get("data"):
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get("resume"):
overrides["resume"] = self.ckpt_path
@ -223,6 +222,13 @@ class YOLO:
return model_class, trainer_class, validator_class, predictor_class
@property
def names(self):
"""
Returns class names of the loaded model.
"""
return self.model.names
@staticmethod
def _reset_ckpt_args(args):
args.pop("project", None)