Search for model metadata with TensorFlow GraphDef (#13389)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
22dec59b57
commit
fd854a7c68
1 changed files with 6 additions and 2 deletions
|
|
@ -320,6 +320,10 @@ class AutoBackend(nn.Module):
|
|||
with open(w, "rb") as f:
|
||||
gd.ParseFromString(f.read())
|
||||
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
|
||||
try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
|
||||
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
|
||||
except StopIteration:
|
||||
pass # no metadata file found
|
||||
|
||||
# TFLite or TFLite Edge TPU
|
||||
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
|
||||
|
|
@ -402,7 +406,7 @@ class AutoBackend(nn.Module):
|
|||
# Load external metadata YAML
|
||||
if isinstance(metadata, (str, Path)) and Path(metadata).exists():
|
||||
metadata = yaml_load(metadata)
|
||||
if metadata:
|
||||
if metadata and isinstance(metadata, dict):
|
||||
for k, v in metadata.items():
|
||||
if k in {"stride", "batch"}:
|
||||
metadata[k] = int(v)
|
||||
|
|
@ -563,7 +567,7 @@ class AutoBackend(nn.Module):
|
|||
y = [y]
|
||||
elif self.pb: # GraphDef
|
||||
y = self.frozen_func(x=self.tf.constant(im))
|
||||
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
|
||||
if (self.task == "segment" or len(y) == 2) and len(self.names) == 999: # segments and names not defined
|
||||
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
|
||||
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
|
||||
self.names = {i: f"class{i}" for i in range(nc)}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue