Code Refactor for Speed and Readability (#13450)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Author: Paula Derrenger · 2024-06-09 17:38:05 +02:00 · committed by GitHub
parent 1b26838def
commit 6367ff4748
3 changed files with 35 additions and 28 deletions


@@ -383,44 +383,44 @@ class TinyViTBlock(nn.Module):
         """Applies attention-based transformation or padding to input 'x' before passing it through a local
         convolution.
         """
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
+        h, w = self.input_resolution
+        b, l, c = x.shape
+        assert l == h * w, "input feature has wrong size"
         res_x = x
-        if H == self.window_size and W == self.window_size:
+        if h == self.window_size and w == self.window_size:
             x = self.attn(x)
         else:
-            x = x.view(B, H, W, C)
-            pad_b = (self.window_size - H % self.window_size) % self.window_size
-            pad_r = (self.window_size - W % self.window_size) % self.window_size
+            x = x.view(b, h, w, c)
+            pad_b = (self.window_size - h % self.window_size) % self.window_size
+            pad_r = (self.window_size - w % self.window_size) % self.window_size
             padding = pad_b > 0 or pad_r > 0
             if padding:
                 x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

-            pH, pW = H + pad_b, W + pad_r
+            pH, pW = h + pad_b, w + pad_r
             nH = pH // self.window_size
             nW = pW // self.window_size

             # Window partition
             x = (
-                x.view(B, nH, self.window_size, nW, self.window_size, C)
+                x.view(b, nH, self.window_size, nW, self.window_size, c)
                 .transpose(2, 3)
-                .reshape(B * nH * nW, self.window_size * self.window_size, C)
+                .reshape(b * nH * nW, self.window_size * self.window_size, c)
             )
             x = self.attn(x)

             # Window reverse
-            x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
+            x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)

             if padding:
-                x = x[:, :H, :W].contiguous()
+                x = x[:, :h, :w].contiguous()

-            x = x.view(B, L, C)
+            x = x.view(b, l, c)

         x = res_x + self.drop_path(x)
-        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = x.transpose(1, 2).reshape(b, c, h, w)
         x = self.local_conv(x)
-        x = x.view(B, C, L).transpose(1, 2)
+        x = x.view(b, c, l).transpose(1, 2)
         return x + self.drop_path(self.mlp(x))
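
Note on the hunk above: the window partition and window reverse are an exact round trip of views and transposes, so per-window attention sees every token exactly once and nothing is lost on the way back; the `(window_size - h % window_size) % window_size` expression computes the smallest non-negative padding that makes each side a multiple of the window. A minimal standalone sketch of the shape algebra (the tensor sizes here are illustrative, not taken from the diff):

import torch

# Illustrative sizes: batch 2, a 14x14 feature map, 8 channels, window 7,
# so the map tiles into nH x nW = 2 x 2 windows with no padding needed.
b, h, w, c, ws = 2, 14, 14, 8, 7
nH, nW = h // ws, w // ws

x = torch.randn(b, h, w, c)

# Window partition: (b, h, w, c) -> (b * nH * nW, ws * ws, c),
# one row of tokens per window, ready for per-window attention.
windows = x.view(b, nH, ws, nW, ws, c).transpose(2, 3).reshape(b * nH * nW, ws * ws, c)

# Window reverse: the same ops undone, back to (b, h, w, c).
restored = windows.view(b, nH, nW, ws, ws, c).transpose(2, 3).reshape(b, h, w, c)

assert torch.equal(x, restored)  # lossless round trip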
@@ -565,10 +565,10 @@ class TinyViT(nn.Module):
         img_size=224,
         in_chans=3,
         num_classes=1000,
-        embed_dims=[96, 192, 384, 768],
-        depths=[2, 2, 6, 2],
-        num_heads=[3, 6, 12, 24],
-        window_sizes=[7, 7, 14, 7],
+        embed_dims=(96, 192, 384, 768),
+        depths=(2, 2, 6, 2),
+        num_heads=(3, 6, 12, 24),
+        window_sizes=(7, 7, 14, 7),
         mlp_ratio=4.0,
         drop_rate=0.0,
         drop_path_rate=0.1,
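
The list-to-tuple change in the signature above is the usual guard against Python's shared-mutable-default pitfall: a list default is built once at definition time and reused by every call, so any in-place mutation leaks into later calls, whereas a tuple default cannot be mutated. A small sketch of the failure mode (function names here are hypothetical):

def risky(depths=[2, 2, 6, 2]):
    depths[-1] *= 2  # mutates the one shared default list
    return depths

def safe(depths=(2, 2, 6, 2)):
    depths = list(depths)  # work on a per-call copy
    depths[-1] *= 2
    return depths

print(risky())  # [2, 2, 6, 4]
print(risky())  # [2, 2, 6, 8]  <- state leaked from the previous call
print(safe())   # [2, 2, 6, 4]
print(safe())   # [2, 2, 6, 4]  <- identical on every call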
@@ -732,8 +732,8 @@ class TinyViT(nn.Module):
         for i in range(start_i, len(self.layers)):
             layer = self.layers[i]
             x = layer(x)

-        B, _, C = x.shape
-        x = x.view(B, 64, 64, C)
+        batch, _, channel = x.shape
+        x = x.view(batch, 64, 64, channel)
         x = x.permute(0, 3, 1, 2)
         return self.neck(x)
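
The renamed reshape in this last hunk folds the token sequence back into a spatial map: the final stage emits 64 * 64 tokens per image, which are laid out on a 64x64 grid and permuted from NHWC to NCHW before the convolutional neck. A shape walk-through with an assumed channel width (the value 256 is illustrative only):

import torch

batch, channel = 2, 256                    # channel width assumed for illustration
x = torch.randn(batch, 64 * 64, channel)   # (batch, L, channel) from the last layer

batch, _, channel = x.shape
x = x.view(batch, 64, 64, channel)  # fold the 4096 tokens onto a 64x64 grid
x = x.permute(0, 3, 1, 2)           # NHWC -> NCHW for the conv-based neck
print(x.shape)                      # torch.Size([2, 256, 64, 64])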