New guess_model_task() function (#614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
520825c4b2
commit
59d4335664
6 changed files with 63 additions and 29 deletions
|
|
@ -3,12 +3,13 @@
|
|||
from pathlib import Path
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||
guess_model_task)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
MODEL_MAP = {
|
||||
|
|
@ -73,9 +74,9 @@ class YOLO:
|
|||
"""
|
||||
cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(cfg, append_filename=True) # model dict
|
||||
self.task = guess_task_from_model_yaml(cfg_dict)
|
||||
self.task = guess_model_task(cfg_dict)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._guess_ops_from_task(self.task)
|
||||
self._assign_ops_from_task(self.task)
|
||||
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize
|
||||
self.cfg = cfg
|
||||
|
||||
|
|
@ -92,7 +93,7 @@ class YOLO:
|
|||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \
|
||||
self._guess_ops_from_task(self.task)
|
||||
self._assign_ops_from_task(self.task)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
|
|
@ -217,7 +218,7 @@ class YOLO:
|
|||
"""
|
||||
self.model.to(device)
|
||||
|
||||
def _guess_ops_from_task(self, task):
|
||||
def _assign_ops_from_task(self, task):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue