Threadpool fixes and CLI improvements (#550)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
d9a0fba251
commit
21b701c4ea
22 changed files with 338 additions and 251 deletions
|
|
@ -6,7 +6,7 @@ from ultralytics import yolo # noqa
|
|||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_PATH, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ class YOLO:
|
|||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "val"
|
||||
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
|
||||
|
|
@ -169,7 +169,7 @@ class YOLO:
|
|||
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_cfg(cfg=DEFAULT_CFG_PATH, overrides=overrides)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
print(args)
|
||||
|
|
@ -181,8 +181,7 @@ class YOLO:
|
|||
Trains the model on a given dataset.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
|
|
@ -192,7 +191,7 @@ class YOLO:
|
|||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get("resume"):
|
||||
overrides["resume"] = self.ckpt_path
|
||||
|
||||
|
|
@ -223,6 +222,13 @@ class YOLO:
|
|||
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
"""
|
||||
Returns class names of the loaded model.
|
||||
"""
|
||||
return self.model.names
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
args.pop("project", None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue