diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 4b66d585..cdb635e3 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -1015,12 +1015,17 @@ class Exporter: def _add_tflite_metadata(self, file): """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.""" - from tflite_support import flatbuffers # noqa - from tflite_support import metadata as _metadata # noqa - from tflite_support import metadata_schema_py_generated as _metadata_fb # noqa + import flatbuffers + + if MACOS: # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845 + from tflite_support import metadata # noqa + from tflite_support import metadata_schema_py_generated as schema # noqa + else: + from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa + from tensorflow_lite_support.metadata.python import metadata # noqa # Create model info - model_meta = _metadata_fb.ModelMetadataT() + model_meta = schema.ModelMetadataT() model_meta.name = self.metadata["description"] model_meta.version = self.metadata["version"] model_meta.author = self.metadata["author"] @@ -1031,41 +1036,41 @@ class Exporter: with open(tmp_file, "w") as f: f.write(str(self.metadata)) - label_file = _metadata_fb.AssociatedFileT() + label_file = schema.AssociatedFileT() label_file.name = tmp_file.name - label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS + label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS # Create input info - input_meta = _metadata_fb.TensorMetadataT() + input_meta = schema.TensorMetadataT() input_meta.name = "image" input_meta.description = "Input image to be detected." - input_meta.content = _metadata_fb.ContentT() - input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT() - input_meta.content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB - input_meta.content.contentPropertiesType = _metadata_fb.ContentProperties.ImageProperties + input_meta.content = schema.ContentT() + input_meta.content.contentProperties = schema.ImagePropertiesT() + input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB + input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties # Create output info - output1 = _metadata_fb.TensorMetadataT() + output1 = schema.TensorMetadataT() output1.name = "output" output1.description = "Coordinates of detected objects, class labels, and confidence score" output1.associatedFiles = [label_file] if self.model.task == "segment": - output2 = _metadata_fb.TensorMetadataT() + output2 = schema.TensorMetadataT() output2.name = "output" output2.description = "Mask protos" output2.associatedFiles = [label_file] # Create subgraph info - subgraph = _metadata_fb.SubGraphMetadataT() + subgraph = schema.SubGraphMetadataT() subgraph.inputTensorMetadata = [input_meta] subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] model_meta.subgraphMetadata = [subgraph] b = flatbuffers.Builder(0) - b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) metadata_buf = b.Output() - populator = _metadata.MetadataPopulator.with_model_file(str(file)) + populator = metadata.MetadataPopulator.with_model_file(str(file)) populator.load_metadata_buffer(metadata_buf) populator.load_associated_files([str(tmp_file)]) populator.populate()