Add CoreML iOS updates (#121)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2022-12-30 21:33:43 +01:00 committed by GitHub
parent fec13ec773
commit c9f3e469cb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 215 additions and 91 deletions

View file

@ -1,5 +1,3 @@
from pathlib import Path
import torch
from ultralytics import yolo # noqa required for python usage
@ -7,9 +5,9 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# map head: [model, trainer, validator, predictor]
MODEL_MAP = {
@ -63,7 +61,7 @@ class YOLO:
cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(cfg) # model dict
obj = cls(init_key=cls.__init_key)
obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
obj.task = guess_task_from_head(cfg_dict["head"][-1][-2])
obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
obj.model = obj.ModelClass(cfg_dict, verbose=verbose) # initialize
obj.cfg = cfg
@ -132,13 +130,7 @@ class YOLO:
overrides["mode"] = "predict"
predictor = self.PredictorClass(overrides=overrides)
# check size type
sz = predictor.args.imgsz
if type(sz) != int: # received listConfig
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
else:
predictor.args.imgsz = [sz, sz]
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
predictor.setup(model=self.model, source=source)
predictor()
@ -179,7 +171,7 @@ class YOLO:
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
args.task = self.task
exporter = Exporter(overrides=overrides)
exporter = Exporter(overrides=args)
exporter(model=self.model)
def train(self, **kwargs):
@ -230,21 +222,6 @@ class YOLO:
self.trainer.train()
@staticmethod
def _guess_task_from_head(head):
task = None
if head.lower() in ["classify", "classifier", "cls", "fc"]:
task = "classify"
if head.lower() in ["detect"]:
task = "detect"
if head.lower() in ["segment"]:
task = "segment"
if not task:
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
return task
def to(self, device):
self.model.to(device)