Tests and docstrings improvements (#4475)

Glenn Jocher 2023-08-21 17:02:14 +02:00 committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
22 changed files with 107 additions and 186 deletions


@@ -35,8 +35,7 @@ class Conv2d_BN(torch.nn.Sequential):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
w.size(0),
w.shape[2:],
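
The collapsed expression for `b` here is the standard Conv+BatchNorm fusion: the BN scale is folded into the convolution weight and the BN shift becomes the new bias. A minimal sketch of that identity on a throwaway Conv2d/BatchNorm2d pair (names and sizes below are illustrative, not taken from this file):

import torch
import torch.nn as nn

# Toy Conv + BN pair in eval mode so BN uses its running statistics.
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)

# Fold BN into the conv, mirroring the formulas shown above in Conv2d_BN.
scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True).eval()
with torch.no_grad():
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_(bn.bias - bn.running_mean * scale)

x = torch.randn(2, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
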
@@ -72,8 +71,7 @@ class PatchEmbed(nn.Module):
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * \
self.patches_resolution[1]
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
@@ -110,21 +108,14 @@ class MBConv(nn.Module):
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act3(x)
return x
return self.act3(x)
class PatchMerging(nn.Module):
@@ -137,9 +128,7 @@ class PatchMerging(nn.Module):
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 2
if (out_dim == 320 or out_dim == 448 or out_dim == 576):
stride_c = 1
stride_c = 1 if out_dim in [320, 448, 576] else 2
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
@@ -156,8 +145,7 @@ class PatchMerging(nn.Module):
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
return x.flatten(2).transpose(1, 2)
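
The `flatten(2).transpose(1, 2)` that PatchMerging now returns directly is the usual NCHW-to-token-sequence conversion: one token per spatial location, which is the layout the transformer blocks consume. A small shape-only illustration (tensor names are illustrative):

import torch

B, C, H, W = 2, 64, 32, 32
feat = torch.randn(B, C, H, W)

# (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): one token per spatial position.
tokens = feat.flatten(2).transpose(1, 2)
assert tokens.shape == (B, H * W, C)

# The inverse reshape recovers the spatial map (cf. the view/permute in forward_features).
restored = tokens.transpose(1, 2).reshape(B, C, H, W)
assert torch.equal(restored, feat)
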
class ConvLayer(nn.Module):
@@ -174,7 +162,6 @@ class ConvLayer(nn.Module):
out_dim=None,
conv_expand_ratio=4.,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@@ -192,20 +179,13 @@ class ConvLayer(nn.Module):
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
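
The `checkpoint.checkpoint(blk, x)` branch kept by this rewrite is PyTorch activation checkpointing: the block's intermediate activations are discarded in the forward pass and recomputed during backward, trading extra compute for lower memory. A minimal sketch of the same loop with illustrative stand-in blocks:

import torch
import torch.nn as nn
from torch.utils import checkpoint

blocks = nn.ModuleList(nn.Sequential(nn.Linear(128, 128), nn.GELU()) for _ in range(4))
use_checkpoint = True  # corresponds to the use_checkpoint flag in ConvLayer/BasicLayer

x = torch.randn(8, 128, requires_grad=True)
for blk in blocks:
    # Recompute blk's forward in the backward pass instead of caching activations.
    x = checkpoint.checkpoint(blk, x) if use_checkpoint else blk(x)
x.sum().backward()
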
class Mlp(nn.Module):
@@ -222,13 +202,11 @@ class Mlp(nn.Module):
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
return self.drop(x)
class Attention(torch.nn.Module):
@@ -297,12 +275,12 @@ class Attention(torch.nn.Module):
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
return self.proj(x)
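
For reference, the body above is the tail of a standard scaled dot-product attention in which a learned positional bias (`attention_biases` gathered via `attention_bias_idxs`) is added to the logits before the softmax. A compact, self-contained sketch of that pattern (dimensions are illustrative, not the class's actual hyperparameters):

import torch

B, num_heads, N, d = 2, 4, 49, 16        # batch, heads, tokens, per-head dim
q, k, v = (torch.randn(B, num_heads, N, d) for _ in range(3))
bias = torch.randn(num_heads, N, N)      # learned additive bias, shared across the batch

attn = (q @ k.transpose(-2, -1)) * d ** -0.5 + bias  # scaled logits + positional bias
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * d)  # merge heads back
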
class TinyViTBlock(nn.Module):
r""" TinyViT Block.
"""
TinyViT Block.
Args:
dim (int): Number of input channels.
@@ -312,8 +290,7 @@ class TinyViTBlock(nn.Module):
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
local_conv_size (int): the kernel size of the convolution between
Attention and MLP. Default: 3
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
activation (torch.nn): the activation function. Default: nn.GELU
"""
@@ -391,8 +368,7 @@ class TinyViTBlock(nn.Module):
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
return x + self.drop_path(self.mlp(x))
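
The MLP residual here is added through `drop_path`, i.e. stochastic depth: during training the residual branch is zeroed for a random subset of samples and the survivors are rescaled by 1/(1-p); at inference it is the identity. A minimal sketch of that behaviour (an illustrative implementation, not the one this file imports):

import torch
import torch.nn as nn

class DropPathSketch(nn.Module):
    # Zero the residual branch per sample with probability p (stochastic depth).
    def __init__(self, p: float = 0.1):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        keep = 1.0 - self.p
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep)
        return x * mask / keep

# Usage mirroring the line above: x = x + drop_path(mlp(x))
drop_path = DropPathSketch(0.1).train()
tokens = torch.randn(4, 49, 96)
tokens = tokens + drop_path(torch.randn(4, 49, 96))
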
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
@@ -400,7 +376,8 @@ class TinyViTBlock(nn.Module):
class BasicLayer(nn.Module):
""" A basic TinyViT layer for one stage.
"""
A basic TinyViT layer for one stage.
Args:
dim (int): Number of input channels.
@@ -434,7 +411,6 @@ class BasicLayer(nn.Module):
activation=nn.GELU,
out_dim=None,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@@ -456,20 +432,13 @@ class BasicLayer(nn.Module):
) for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
else:
self.downsample = None
self.downsample = None if downsample is None else downsample(
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
@@ -487,8 +456,7 @@ class LayerNorm2d(nn.Module):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
return self.weight[:, None, None] * x + self.bias[:, None, None]
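
LayerNorm2d normalizes over the channel dimension of an NCHW tensor at each spatial location, which matches running nn.LayerNorm on the channels-last view. A quick equivalence check mirroring the formula above (shapes and eps are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 64, 64)                        # NCHW feature map
weight, bias, eps = torch.randn(256), torch.randn(256), 1e-6

# Manual channel-wise normalization, as in LayerNorm2d.forward above.
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
manual = weight[:, None, None] * ((x - u) / torch.sqrt(s + eps)) + bias[:, None, None]

# Same result via F.layer_norm on the channels-last permutation.
ref = F.layer_norm(x.permute(0, 2, 3, 1), (256,), weight, bias, eps).permute(0, 3, 1, 2)
assert torch.allclose(manual, ref, atol=1e-5)
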
class TinyViT(nn.Module):
@@ -548,10 +516,7 @@ class TinyViT(nn.Module):
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(
conv_expand_ratio=mbconv_expand_ratio,
**kwargs,
)
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
layer = BasicLayer(num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
@@ -622,7 +587,7 @@ class TinyViT(nn.Module):
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
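
Hooks like this `_init_weights` are normally applied recursively to every submodule via `Module.apply`; per the NOTE above, the truncated-normal weight init is only needed when training from scratch, so only the constant bias/LayerNorm resets remain active here. A sketch of the surrounding pattern (module sizes are illustrative):

import torch.nn as nn

def _init_weights_sketch(m: nn.Module) -> None:
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)  # only needed when training from scratch
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)

model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
model.apply(_init_weights_sketch)  # visits every submodule, including nested ones
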
@@ -645,9 +610,7 @@ class TinyViT(nn.Module):
B, _, C = x.size()
x = x.view(B, 64, 64, C)
x = x.permute(0, 3, 1, 2)
x = self.neck(x)
return x
return self.neck(x)
def forward(self, x):
x = self.forward_features(x)
return x
return self.forward_features(x)