Check PyTorch model status for all YOLO methods (#945)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
fd5be10c66
commit
20fe708f31
21 changed files with 180 additions and 106 deletions
|
|
@ -6,11 +6,11 @@ from typing import List
|
|||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||
guess_model_task)
|
||||
guess_model_task, nn)
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
|
|
@ -55,19 +55,16 @@ class YOLO:
|
|||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics_data = None
|
||||
|
||||
# Load or create new YOLO model
|
||||
suffix = Path(model).suffix
|
||||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
try:
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
else:
|
||||
self._load(model)
|
||||
except Exception as e:
|
||||
raise NotImplementedError(f"Unable to load model='{model}'. "
|
||||
f"As an example try model='yolov8n.pt' or model='yolov8n.yaml'") from e
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
else:
|
||||
self._load(model)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
|
@ -100,15 +97,27 @@ class YOLO:
|
|||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
else:
|
||||
check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
"""
|
||||
Raises TypeError is model is not a PyTorch model
|
||||
"""
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"model='{self.model}' must be a PyTorch model, but is a different type. PyTorch models "
|
||||
f"can be used to train, val, predict and export, i.e. "
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the model modules.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
|
|
@ -122,9 +131,11 @@ class YOLO:
|
|||
Args:
|
||||
verbose (bool): Controls verbosity.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.info(verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
|
|
@ -176,6 +187,8 @@ class YOLO:
|
|||
|
||||
validator = self.ValidatorClass(args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics_data = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
|
|
@ -186,7 +199,7 @@ class YOLO:
|
|||
Args:
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
|
|
@ -196,7 +209,7 @@ class YOLO:
|
|||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
return exporter(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
|
|
@ -205,6 +218,7 @@ class YOLO:
|
|||
Args:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get("cfg"):
|
||||
|
|
@ -226,6 +240,7 @@ class YOLO:
|
|||
if RANK in {0, -1}:
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
self.metrics_data = self.trainer.validator.metrics
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
|
|
@ -234,15 +249,14 @@ class YOLO:
|
|||
Args:
|
||||
device (str): device
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.to(device)
|
||||
|
||||
def _assign_ops_from_task(self):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
||||
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
||||
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
|
|
@ -250,7 +264,7 @@ class YOLO:
|
|||
"""
|
||||
Returns class names of the loaded model.
|
||||
"""
|
||||
return self.model.names
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
|
|
@ -259,6 +273,16 @@ class YOLO:
|
|||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""
|
||||
Returns metrics if computed
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info("No metrics data found! Run training or validation operation first.")
|
||||
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
"""
|
||||
|
|
@ -269,5 +293,5 @@ class YOLO:
|
|||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots':
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||
args.pop(arg, None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue