Refactor Python code (#13448)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-06-09 02:32:17 +02:00 committed by GitHub
parent 6a234f3639
commit 1b26838def
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 81 additions and 101 deletions

View file

@ -320,10 +320,8 @@ class AutoBackend(nn.Module):
with open(w, "rb") as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
try: # attempt to retrieve metadata from SavedModel file potentially alongside GraphDef file
with contextlib.suppress(StopIteration): # find metadata in SavedModel alongside GraphDef
metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml"))
except StopIteration:
pass # no metadata file found
# TFLite or TFLite Edge TPU
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python

View file

@ -666,8 +666,7 @@ class CBLinear(nn.Module):
def forward(self, x):
"""Forward pass through CBLinear layer."""
outs = self.conv(x).split(self.c2s, dim=1)
return outs
return self.conv(x).split(self.c2s, dim=1)
class CBFuse(nn.Module):
@ -682,5 +681,4 @@ class CBFuse(nn.Module):
"""Forward pass through CBFuse layer."""
target_size = xs[-1].shape[2:]
res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
return out
return torch.sum(torch.stack(res + xs[-1:]), dim=0)