Fix ambiguous variable names (#13864)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Alex Pasquali <alexpasquali98@gmail.com>
This commit is contained in:
Glenn Jocher 2024-06-21 21:06:37 +02:00 committed by GitHub
parent c497732278
commit ee859ac64d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 22 additions and 24 deletions

View file

@ -384,8 +384,8 @@ class TinyViTBlock(nn.Module):
convolution.
"""
h, w = self.input_resolution
b, l, c = x.shape
assert l == h * w, "input feature has wrong size"
b, hw, c = x.shape # batch, height*width, channels
assert hw == h * w, "input feature has wrong size"
res_x = x
if h == self.window_size and w == self.window_size:
x = self.attn(x)
@ -394,13 +394,13 @@ class TinyViTBlock(nn.Module):
pad_b = (self.window_size - h % self.window_size) % self.window_size
pad_r = (self.window_size - w % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = h + pad_b, w + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# Window partition
x = (
x.view(b, nH, self.window_size, nW, self.window_size, c)
@ -408,19 +408,18 @@ class TinyViTBlock(nn.Module):
.reshape(b * nH * nW, self.window_size * self.window_size, c)
)
x = self.attn(x)
# Window reverse
x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
if padding:
x = x[:, :h, :w].contiguous()
x = x.view(b, l, c)
x = x.view(b, hw, c)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(b, c, h, w)
x = self.local_conv(x)
x = x.view(b, c, l).transpose(1, 2)
x = x.view(b, c, hw).transpose(1, 2)
return x + self.drop_path(self.mlp(x))