ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 30484d5925
commit ea527507fe
41 changed files with 97 additions and 93 deletions
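
The change is mechanical throughout: literal tuples used only for membership tests become literal sets. A minimal micro-benchmark, not part of the commit and with illustrative names and counts, shows why this helps: when a literal set of constants appears on the right-hand side of an `in` test, CPython folds it into a frozenset constant, so the test is a hash lookup rather than a linear scan of a tuple. Absolute timings vary by machine.

    # Illustrative micro-benchmark, not from the repository.
    import timeit

    stmt_tuple = '"tfjs" in ("saved_model", "pb", "tflite", "edgetpu", "tfjs")'  # worst case: last element
    stmt_set = '"tfjs" in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}'

    t_tuple = timeit.timeit(stmt_tuple, number=1_000_000)  # linear scan of a tuple
    t_set = timeit.timeit(stmt_set, number=1_000_000)      # hash lookup in a frozenset constant
    print(f"tuple: {t_tuple:.3f}s  set: {t_set:.3f}s")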

ultralytics/nn/autobackend.py
@@ -374,9 +374,9 @@ class AutoBackend(nn.Module):
             metadata = yaml_load(metadata)
         if metadata:
             for k, v in metadata.items():
-                if k in ("stride", "batch"):
+                if k in {"stride", "batch"}:
                     metadata[k] = int(v)
-                elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
+                elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
                     metadata[k] = eval(v)
             stride = metadata["stride"]
             task = metadata["task"]
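
For reference, a standalone sketch of the coercion pattern in the hunk above; coerce_metadata is a hypothetical helper, not a function in the repository, and the sample dict is made up.

    def coerce_metadata(metadata: dict) -> dict:
        """Hypothetical helper mirroring the AutoBackend metadata loop above."""
        for k, v in metadata.items():
            if k in {"stride", "batch"}:  # integer-valued fields
                metadata[k] = int(v)
            elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
                metadata[k] = eval(v)  # these values may arrive as their string repr
        return metadata

    print(coerce_metadata({"stride": "32", "imgsz": "[640, 640]", "task": "detect"}))
    # {'stride': 32, 'imgsz': [640, 640], 'task': 'detect'}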

@@ -531,8 +531,8 @@ class AutoBackend(nn.Module):
                     self.names = {i: f"class{i}" for i in range(nc)}
             else:  # Lite or Edge TPU
                 details = self.input_details[0]
-                integer = details["dtype"] in (np.int8, np.int16)  # is TFLite quantized int8 or int16 model
-                if integer:
+                is_int = details["dtype"] in {np.int8, np.int16}  # is TFLite quantized int8 or int16 model
+                if is_int:
                     scale, zero_point = details["quantization"]
                     im = (im / scale + zero_point).astype(details["dtype"])  # de-scale
                 self.interpreter.set_tensor(details["index"], im)

@@ -540,7 +540,7 @@ class AutoBackend(nn.Module):
                 y = []
                 for output in self.output_details:
                     x = self.interpreter.get_tensor(output["index"])
-                    if integer:
+                    if is_int:
                         scale, zero_point = output["quantization"]
                         x = (x.astype(np.float32) - zero_point) * scale  # re-scale
                     if x.ndim == 3:  # if task is not classification, excluding masks (ndim=4) as well
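
The two hunks above are the input (de-scale) and output (re-scale) halves of the TFLite integer-quantization round trip. A standalone sketch with made-up scale and zero-point values:

    import numpy as np

    scale, zero_point = 1 / 255, -128  # hypothetical per-tensor quantization parameters
    im = np.random.rand(1, 3, 640, 640).astype(np.float32)  # float input in [0, 1)

    q = (im / scale + zero_point).astype(np.int8)            # de-scale: float -> int8
    restored = (q.astype(np.float32) - zero_point) * scale   # re-scale: int8 -> float
    print(np.abs(restored - im).max())  # error bounded by roughly one quantization step (~scale)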

ultralytics/nn/modules/conv.py
@@ -296,7 +296,7 @@ class SpatialAttention(nn.Module):
     def __init__(self, kernel_size=7):
         """Initialize Spatial-attention module with kernel size argument."""
         super().__init__()
-        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
+        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
         padding = 3 if kernel_size == 7 else 1
         self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
         self.act = nn.Sigmoid()
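
A quick usage sketch of the assertion touched above; it assumes an ultralytics install, with the import path taken from the repository's nn.modules layout.

    import torch
    from ultralytics.nn.modules.conv import SpatialAttention

    sa = SpatialAttention(kernel_size=7)
    out = sa(torch.randn(1, 64, 32, 32))
    print(out.shape)  # torch.Size([1, 64, 32, 32]) -- the attention map rescales x, so shape is preserved
    # SpatialAttention(kernel_size=5)  # would raise AssertionError: kernel size must be 3 or 7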

ultralytics/nn/modules/head.py
@@ -54,13 +54,13 @@ class Detect(nn.Module):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
 
-        if self.export and self.format in ("tflite", "edgetpu"):
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
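
The format check above runs on every forward pass, so the constant-set form matters: CPython compiles a membership test against a literal set of constants into a single frozenset constant. A sketch using the standard dis module; the helper function is illustrative, not repository code.

    import dis

    def is_tf_export(fmt):
        """Illustrative helper with the same membership test as Detect above."""
        return fmt in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}

    dis.dis(is_tf_export)  # expect LOAD_CONST frozenset({...}) followed by a CONTAINS_OP / 'in' comparison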

@@ -230,13 +230,13 @@ class WorldDetect(Detect):
             self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
             self.shape = shape
 
-        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
             box = x_cat[:, : self.reg_max * 4]
             cls = x_cat[:, self.reg_max * 4 :]
         else:
             box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
 
-        if self.export and self.format in ("tflite", "edgetpu"):
+        if self.export and self.format in {"tflite", "edgetpu"}:
             # Precompute normalization factor to increase numerical stability
             # See https://github.com/ultralytics/ultralytics/issues/7371
             grid_h = shape[2]
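
Both branches of the export check in Detect and WorldDetect above produce the same box/cls tensors; slicing is used only because torch split lowers to a TF FlexSplitV op in TFLite export, as the inline comment notes. A sketch with made-up shapes (reg_max=16, nc=80, 8400 anchor points):

    import torch

    reg_max, nc = 16, 80
    x_cat = torch.randn(1, reg_max * 4 + nc, 8400)

    box_a = x_cat[:, : reg_max * 4]                    # slicing path (TF export)
    cls_a = x_cat[:, reg_max * 4 :]
    box_b, cls_b = x_cat.split((reg_max * 4, nc), 1)   # split path (everything else)

    assert torch.equal(box_a, box_b) and torch.equal(cls_a, cls_b)  # identical results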

ultralytics/nn/tasks.py
@@ -896,7 +896,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 )  # num heads
 
             args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
+            if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
                 args.insert(2, n)  # number of repeats
                 n = 1
         elif m is AIFI:
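
A toy sketch of the repeat-count handling in the hunk above; the classes and the helper are stand-ins, not the real parse_model.

    class C2f: ...   # stand-in for a repeatable block
    class Conv: ...  # stand-in for a non-repeatable module

    def finalize_args(m, args, n):
        """Toy version of the snippet above: repeatable blocks get n inserted at index 2."""
        if m in {C2f}:
            args.insert(2, n)  # number of repeats is handled inside the block itself
            n = 1
        return args, n

    print(finalize_args(C2f, [64, 128], 3))      # ([64, 128, 3], 1)
    print(finalize_args(Conv, [64, 128, 3], 3))  # ([64, 128, 3], 3)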