From fd854a7c6805ebd5c6e25917200ce5dbf1ac9a6e Mon Sep 17 00:00:00 2001 From: Burhan <62214284+Burhan-Q@users.noreply.github.com> Date: Thu, 6 Jun 2024 19:33:37 -0400 Subject: [PATCH] Search for model metadata with TensorFlow GraphDef (#13389) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/nn/autobackend.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 6ec42fce..a1a10cbb 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -320,6 +320,10 @@ class AutoBackend(nn.Module): with open(w, "rb") as f: gd.ParseFromString(f.read()) frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass # no metadata file found # TFLite or TFLite Edge TPU elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python @@ -402,7 +406,7 @@ class AutoBackend(nn.Module): # Load external metadata YAML if isinstance(metadata, (str, Path)) and Path(metadata).exists(): metadata = yaml_load(metadata) - if metadata: + if metadata and isinstance(metadata, dict): for k, v in metadata.items(): if k in {"stride", "batch"}: metadata[k] = int(v) @@ -563,7 +567,7 @@ class AutoBackend(nn.Module): y = [y] elif self.pb: # GraphDef y = self.frozen_func(x=self.tf.constant(im)) - if len(y) == 2 and len(self.names) == 999: # segments and names not defined + if (self.task == "segment" or len(y) == 2) and len(self.names) == 999: # segments and names not defined ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400) self.names = {i: f"class{i}" for i in range(nc)}