ultralytics 8.0.46 TFLite and Benchmarks updates (#1141)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-02-25 09:24:14 -08:00 committed by GitHub
parent 3765f4f6d9
commit a82ee2c779
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 130 additions and 74 deletions

View file

@ -9,8 +9,9 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -150,6 +151,13 @@ class YOLO:
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def _check_pip_update(self):
"""
Inform user of ultralytics package update availability
"""
if is_pip_package():
check_pip_update()
def reset(self):
"""
Resets the model modules.
@ -189,6 +197,10 @@ class YOLO:
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
@ -251,11 +263,12 @@ class YOLO:
Args:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
from ultralytics.yolo.utils.benchmarks import run_benchmarks
self._check_is_pytorch_model()
from ultralytics.yolo.utils.benchmarks import benchmark
overrides = self.model.args.copy()
overrides.update(kwargs)
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
def export(self, **kwargs):
"""
@ -283,6 +296,7 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
self._check_pip_update()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get('cfg'):