Model interface enhancement (#106)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
38d6df55cb
commit
384f0ef1c6
6 changed files with 39 additions and 22 deletions
|
|
@ -43,6 +43,7 @@ class YOLO:
|
|||
self.trainer = None
|
||||
self.task = None
|
||||
self.ckpt = None
|
||||
self.overrides = {}
|
||||
|
||||
def new(self, cfg: str):
|
||||
"""
|
||||
|
|
@ -69,6 +70,10 @@ class YOLO:
|
|||
"""
|
||||
self.ckpt = torch.load(weights, map_location="cpu")
|
||||
self.task = self.ckpt["train_args"]["task"]
|
||||
self.overrides = dict(self.ckpt["train_args"])
|
||||
self.overrides["device"] = '' # reset device
|
||||
LOGGER.info("Device has been reset to ''")
|
||||
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
||||
task=self.task)
|
||||
self.model = attempt_load_weights(weights)
|
||||
|
|
@ -107,6 +112,7 @@ class YOLO:
|
|||
source (str): Accepts all source types accepted by yolo
|
||||
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
|
||||
"""
|
||||
kwargs.update(self.overrides)
|
||||
predictor = self.PredictorClass(overrides=kwargs)
|
||||
|
||||
# check size type
|
||||
|
|
@ -119,7 +125,7 @@ class YOLO:
|
|||
predictor.setup(model=self.model, source=source)
|
||||
predictor()
|
||||
|
||||
def val(self, data, **kwargs):
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset
|
||||
|
||||
|
|
@ -130,8 +136,9 @@ class YOLO:
|
|||
if not self.model:
|
||||
raise Exception("model not initialized!")
|
||||
|
||||
kwargs.update(self.overrides)
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=kwargs)
|
||||
args.data = data
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
|
||||
validator = self.ValidatorClass(args=args)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue