Use tensorflow_lite_support (#13042)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e71efd4830
commit
1a4ac2c6ba
1 changed files with 21 additions and 16 deletions
|
|
@ -1015,12 +1015,17 @@ class Exporter:
|
||||||
|
|
||||||
def _add_tflite_metadata(self, file):
|
def _add_tflite_metadata(self, file):
|
||||||
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
"""Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
|
||||||
from tflite_support import flatbuffers # noqa
|
import flatbuffers
|
||||||
from tflite_support import metadata as _metadata # noqa
|
|
||||||
from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa
|
if MACOS: # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845
|
||||||
|
from tflite_support import metadata # noqa
|
||||||
|
from tflite_support import metadata_schema_py_generated as schema # noqa
|
||||||
|
else:
|
||||||
|
from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa
|
||||||
|
from tensorflow_lite_support.metadata.python import metadata # noqa
|
||||||
|
|
||||||
# Create model info
|
# Create model info
|
||||||
model_meta = _metadata_fb.ModelMetadataT()
|
model_meta = schema.ModelMetadataT()
|
||||||
model_meta.name = self.metadata["description"]
|
model_meta.name = self.metadata["description"]
|
||||||
model_meta.version = self.metadata["version"]
|
model_meta.version = self.metadata["version"]
|
||||||
model_meta.author = self.metadata["author"]
|
model_meta.author = self.metadata["author"]
|
||||||
|
|
@ -1031,41 +1036,41 @@ class Exporter:
|
||||||
with open(tmp_file, "w") as f:
|
with open(tmp_file, "w") as f:
|
||||||
f.write(str(self.metadata))
|
f.write(str(self.metadata))
|
||||||
|
|
||||||
label_file = _metadata_fb.AssociatedFileT()
|
label_file = schema.AssociatedFileT()
|
||||||
label_file.name = tmp_file.name
|
label_file.name = tmp_file.name
|
||||||
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
|
label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS
|
||||||
|
|
||||||
# Create input info
|
# Create input info
|
||||||
input_meta = _metadata_fb.TensorMetadataT()
|
input_meta = schema.TensorMetadataT()
|
||||||
input_meta.name = "image"
|
input_meta.name = "image"
|
||||||
input_meta.description = "Input image to be detected."
|
input_meta.description = "Input image to be detected."
|
||||||
input_meta.content = _metadata_fb.ContentT()
|
input_meta.content = schema.ContentT()
|
||||||
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
|
input_meta.content.contentProperties = schema.ImagePropertiesT()
|
||||||
input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
|
input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB
|
||||||
input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties
|
input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties
|
||||||
|
|
||||||
# Create output info
|
# Create output info
|
||||||
output1 = _metadata_fb.TensorMetadataT()
|
output1 = schema.TensorMetadataT()
|
||||||
output1.name = "output"
|
output1.name = "output"
|
||||||
output1.description = "Coordinates of detected objects, class labels, and confidence score"
|
output1.description = "Coordinates of detected objects, class labels, and confidence score"
|
||||||
output1.associatedFiles = [label_file]
|
output1.associatedFiles = [label_file]
|
||||||
if self.model.task == "segment":
|
if self.model.task == "segment":
|
||||||
output2 = _metadata_fb.TensorMetadataT()
|
output2 = schema.TensorMetadataT()
|
||||||
output2.name = "output"
|
output2.name = "output"
|
||||||
output2.description = "Mask protos"
|
output2.description = "Mask protos"
|
||||||
output2.associatedFiles = [label_file]
|
output2.associatedFiles = [label_file]
|
||||||
|
|
||||||
# Create subgraph info
|
# Create subgraph info
|
||||||
subgraph = _metadata_fb.SubGraphMetadataT()
|
subgraph = schema.SubGraphMetadataT()
|
||||||
subgraph.inputTensorMetadata = [input_meta]
|
subgraph.inputTensorMetadata = [input_meta]
|
||||||
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
|
subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1]
|
||||||
model_meta.subgraphMetadata = [subgraph]
|
model_meta.subgraphMetadata = [subgraph]
|
||||||
|
|
||||||
b = flatbuffers.Builder(0)
|
b = flatbuffers.Builder(0)
|
||||||
b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
|
||||||
metadata_buf = b.Output()
|
metadata_buf = b.Output()
|
||||||
|
|
||||||
populator = _metadata.MetadataPopulator.with_model_file(str(file))
|
populator = metadata.MetadataPopulator.with_model_file(str(file))
|
||||||
populator.load_metadata_buffer(metadata_buf)
|
populator.load_metadata_buffer(metadata_buf)
|
||||||
populator.load_associated_files([str(tmp_file)])
|
populator.load_associated_files([str(tmp_file)])
|
||||||
populator.populate()
|
populator.populate()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue