Search for model metadata with TensorFlow GraphDef (#13389)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
22dec59b57
commit
fd854a7c68
1 changed files with 6 additions and 2 deletions
|
|
@ -320,6 +320,10 @@ class AutoBackend(nn.Module):
|
||||||
with open(w, "rb") as f:
|
with open(w, "rb") as f:
|
||||||
gd.ParseFromString(f.read())
|
gd.ParseFromString(f.read())
|
||||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||||
|
try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
|
||||||
|
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
|
||||||
|
except StopIteration:
|
||||||
|
pass # no metadata file found
|
||||||
|
|
||||||
# TFLite or TFLite Edge TPU
|
# TFLite or TFLite Edge TPU
|
||||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||||
|
|
@ -402,7 +406,7 @@ class AutoBackend(nn.Module):
|
||||||
# Load external metadata YAML
|
# Load external metadata YAML
|
||||||
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
||||||
metadata = yaml_load(metadata)
|
metadata = yaml_load(metadata)
|
||||||
if metadata:
|
if metadata and isinstance(metadata, dict):
|
||||||
for k, v in metadata.items():
|
for k, v in metadata.items():
|
||||||
if k in {"stride", "batch"}:
|
if k in {"stride", "batch"}:
|
||||||
metadata[k] = int(v)
|
metadata[k] = int(v)
|
||||||
|
|
@ -563,7 +567,7 @@ class AutoBackend(nn.Module):
|
||||||
y = [y]
|
y = [y]
|
||||||
elif self.pb: # GraphDef
|
elif self.pb: # GraphDef
|
||||||
y = self.frozen_func(x=self.tf.constant(im))
|
y = self.frozen_func(x=self.tf.constant(im))
|
||||||
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
|
if (self.task == "segment" or len(y) == 2) and len(self.names) == 999: # segments and names not defined
|
||||||
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
|
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
|
||||||
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
||||||
self.names = {i: f"class{i}" for i in range(nc)}
|
self.names = {i: f"class{i}" for i in range(nc)}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue