Add CoreML iOS updates (#121)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
fec13ec773
commit
c9f3e469cb
13 changed files with 215 additions and 91 deletions
|
|
@ -1,5 +1,3 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
|
|
@ -7,9 +5,9 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
|
|||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
# map head: [model, trainer, validator, predictor]
|
||||
MODEL_MAP = {
|
||||
|
|
@ -63,7 +61,7 @@ class YOLO:
|
|||
cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(cfg) # model dict
|
||||
obj = cls(init_key=cls.__init_key)
|
||||
obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
|
||||
obj.task = guess_task_from_head(cfg_dict["head"][-1][-2])
|
||||
obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
|
||||
obj.model = obj.ModelClass(cfg_dict, verbose=verbose) # initialize
|
||||
obj.cfg = cfg
|
||||
|
|
@ -132,13 +130,7 @@ class YOLO:
|
|||
overrides["mode"] = "predict"
|
||||
predictor = self.PredictorClass(overrides=overrides)
|
||||
|
||||
# check size type
|
||||
sz = predictor.args.imgsz
|
||||
if type(sz) != int: # received listConfig
|
||||
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
|
||||
else:
|
||||
predictor.args.imgsz = [sz, sz]
|
||||
|
||||
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
|
||||
predictor.setup(model=self.model, source=source)
|
||||
predictor()
|
||||
|
||||
|
|
@ -179,7 +171,7 @@ class YOLO:
|
|||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
exporter = Exporter(overrides=overrides)
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
|
|
@ -230,21 +222,6 @@ class YOLO:
|
|||
|
||||
self.trainer.train()
|
||||
|
||||
@staticmethod
|
||||
def _guess_task_from_head(head):
|
||||
task = None
|
||||
if head.lower() in ["classify", "classifier", "cls", "fc"]:
|
||||
task = "classify"
|
||||
if head.lower() in ["detect"]:
|
||||
task = "detect"
|
||||
if head.lower() in ["segment"]:
|
||||
task = "segment"
|
||||
|
||||
if not task:
|
||||
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
||||
|
||||
return task
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue