ultralytics 8.0.54 TFLite export improvements and fixes (#1447)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-03-16 15:42:44 +01:00 committed by GitHub
parent 30fc4b537f
commit 701fba4770
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 198 additions and 166 deletions

View file

@ -8,8 +8,8 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, ONLINE, RANK, ROOT,
callbacks, is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -153,16 +153,10 @@ class YOLO:
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def _check_pip_update(self):
@smart_inference_mode()
def reset_weights(self):
"""
Inform user of ultralytics package update availability
"""
if ONLINE and is_pip_package():
check_pip_update_available()
def reset(self):
"""
Resets the model modules.
Resets the model modules parameters to randomly initialized values, losing all training information.
"""
self._check_is_pytorch_model()
for m in self.model.modules():
@ -170,6 +164,18 @@ class YOLO:
m.reset_parameters()
for p in self.model.parameters():
p.requires_grad = True
return self
@smart_inference_mode()
def load(self, weights='yolov8n.pt'):
"""
Transfers parameters with matching names and shapes from 'weights' to model.
"""
self._check_is_pytorch_model()
if isinstance(weights, (str, Path)):
weights, self.ckpt = attempt_load_one_weight(weights)
self.model.load(weights)
return self
def info(self, verbose=False):
"""
@ -299,7 +305,7 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
self._check_pip_update()
check_pip_update_available()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get('cfg'):