Check PyTorch model status for all YOLO methods (#945)

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-13 15:08:08 +04:00 committed by GitHub
parent fd5be10c66
commit 20fe708f31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 180 additions and 106 deletions

View file

@ -48,7 +48,6 @@ TensorFlow.js:
$ ln -s ../../yolov5/yolov8n_web_model public/yolov8n_web_model
$ npm start
"""
import contextlib
import json
import os
import platform
@ -74,7 +73,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, __version__, callbacks,
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode, get_latest_opset
MACOS = platform.system() == 'Darwin' # macOS environment
@ -97,6 +96,10 @@ def export_formats():
return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
EXPORT_FORMATS_LIST = list(export_formats()['Argument'][1:])
EXPORT_FORMATS_TABLE = str(export_formats())
def try_export(inner_func):
# YOLOv8 export decorator, i..e @try_export
inner_args = get_default_args(inner_func)
@ -244,7 +247,7 @@ class Exporter:
agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu()
self._add_tflite_metadata(f[8] or f[7], num_outputs=len(self.output_shape))
self._add_tflite_metadata(f[8] or f[7])
if tfjs:
f[9], _ = self._export_tfjs()
if paddle: # PaddlePaddle
@ -253,11 +256,11 @@ class Exporter:
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
f = str(Path(f[-1]))
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
f"\nPredict: yolo task={model.task} mode=predict model={f}"
f"\nValidate: yolo task={model.task} mode=val model={f}"
f"\nVisualize: https://netron.app")
self.run_callbacks("on_export_end")
@ -304,7 +307,7 @@ class Exporter:
self.im.cpu() if dynamic else self.im,
f,
verbose=False,
opset_version=self.args.opset,
opset_version=self.args.opset or get_latest_opset(),
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'],
output_names=output_names,
@ -507,6 +510,10 @@ class Exporter:
# Export to TF SavedModel
subprocess.run(f'onnx2tf -i {onnx} --output_signaturedefs -o {f}', shell=True)
# Add TFLite metadata
for tflite_file in Path(f).rglob('*.tflite'):
self._add_tflite_metadata(tflite_file)
# Load saved_model
keras_model = tf.saved_model.load(f, tags=None, options=None)
@ -661,44 +668,47 @@ class Exporter:
r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity.?.?": {"name": "Identity.?.?"}}}',
r'{"outputs": {"Identity": {"name": "Identity"}, '
r'"Identity_1": {"name": "Identity_1"}, '
r'"Identity_2": {"name": "Identity_2"}, '
r'"Identity_3": {"name": "Identity_3"}}}', f_json.read_text())
r'"Identity_3": {"name": "Identity_3"}}}',
f_json.read_text(),
)
j.write(subst)
return f, None
def _add_tflite_metadata(self, file, num_outputs):
def _add_tflite_metadata(self, file):
# Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
with contextlib.suppress(ImportError):
# check_requirements('tflite_support')
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
check_requirements('tflite_support')
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(self.metadata))
from tflite_support import flatbuffers # noqa
from tflite_support import metadata as _metadata # noqa
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
tmp_file = Path('/tmp/meta.txt')
with open(tmp_file, 'w') as meta_f:
meta_f.write(str(self.metadata))
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
model_meta.subgraphMetadata = [subgraph]
model_meta = _metadata_fb.ModelMetadataT()
label_file = _metadata_fb.AssociatedFileT()
label_file.name = tmp_file.name
model_meta.associatedFiles = [label_file]
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * len(self.output_shape)
model_meta.subgraphMetadata = [subgraph]
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
b = flatbuffers.Builder(0)
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()
populator = _metadata.MetadataPopulator.with_model_file(file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files([str(tmp_file)])
populator.populate()
tmp_file.unlink()
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
# YOLOv8 CoreML pipeline