Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -98,7 +98,11 @@ class MaskDecoder(nn.Module):
|
|||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts masks. See 'forward' for more details."""
|
||||
"""
|
||||
Predicts masks.
|
||||
|
||||
See 'forward' for more details.
|
||||
"""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
|
|
|
|||
|
|
@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
|
||||
and neck.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
|
@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
|
||||
image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
|
|
@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
|
|||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
"""Gets the batch size of the output given the batch size of the input prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
|
|
@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
|
|||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
"""Returns the device of the first point embedding's weight tensor."""
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
|
|
@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
"""Initializes a position embedding using random spatial frequencies."""
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
|
|
@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -351,6 +352,7 @@ class Block(nn.Module):
|
|||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
|
|
@ -404,6 +406,7 @@ class Attention(nn.Module):
|
|||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
|
|
@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
|
|||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
|
|
@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ from ultralytics.utils.instance import to_2tuple
|
|||
class Conv2d_BN(torch.nn.Sequential):
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
|
||||
drop path.
|
||||
"""
|
||||
super().__init__()
|
||||
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
|
|
@ -34,6 +37,9 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
class PatchEmbed(nn.Module):
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
|
|
@ -48,12 +54,16 @@ class PatchEmbed(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
|
||||
return self.seq(x)
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
|
||||
function.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
|
|
@ -73,6 +83,7 @@ class MBConv(nn.Module):
|
|||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Implements the forward pass for the model architecture."""
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
|
|
@ -87,6 +98,9 @@ class MBConv(nn.Module):
|
|||
class PatchMerging(nn.Module):
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
|
||||
optional parameters.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -99,6 +113,7 @@ class PatchMerging(nn.Module):
|
|||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
|
|
@ -149,6 +164,7 @@ class ConvLayer(nn.Module):
|
|||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
"""Processes the input through a series of convolutional layers and returns the activated output."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
|
@ -157,6 +173,7 @@ class ConvLayer(nn.Module):
|
|||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
|
@ -167,6 +184,7 @@ class Mlp(nn.Module):
|
|||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies operations on input x and returns modified x, runs downsample if not None."""
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
|
|
@ -216,6 +234,7 @@ class Attention(torch.nn.Module):
|
|||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, 'ab'):
|
||||
del self.ab
|
||||
|
|
@ -298,6 +317,9 @@ class TinyViTBlock(nn.Module):
|
|||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
|
||||
convolution.
|
||||
"""
|
||||
H, W = self.input_resolution
|
||||
B, L, C = x.shape
|
||||
assert L == H * W, 'input feature has wrong size'
|
||||
|
|
@ -337,6 +359,9 @@ class TinyViTBlock(nn.Module):
|
|||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
||||
attentions heads, window size, and MLP ratio.
|
||||
"""
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
||||
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
|
||||
|
||||
|
|
@ -402,23 +427,28 @@ class BasicLayer(nn.Module):
|
|||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""A PyTorch implementation of Layer Normalization in 2D."""
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform a forward pass, normalizing the input tensor."""
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
|
|
@ -518,6 +548,7 @@ class TinyViT(nn.Module):
|
|||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay):
|
||||
"""Sets the learning rate decay for each layer in the TinyViT model."""
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# layers -> blocks (depth)
|
||||
|
|
@ -525,6 +556,7 @@ class TinyViT(nn.Module):
|
|||
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
||||
|
||||
def _set_lr_scale(m, scale):
|
||||
"""Sets the learning rate scale for each layer in the model based on the layer's depth."""
|
||||
for p in m.parameters():
|
||||
p.lr_scale = scale
|
||||
|
||||
|
|
@ -544,12 +576,14 @@ class TinyViT(nn.Module):
|
|||
p.param_name = k
|
||||
|
||||
def _check_lr_scale(m):
|
||||
"""Checks if the learning rate scale attribute is present in module's parameters."""
|
||||
for p in m.parameters():
|
||||
assert hasattr(p, 'lr_scale'), p.param_name
|
||||
|
||||
self.apply(_check_lr_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
"""Initializes weights for linear layers and layer normalization in the given module."""
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
|
|
@ -561,11 +595,12 @@ class TinyViT(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
||||
return {'attention_biases'}
|
||||
|
||||
def forward_features(self, x):
|
||||
# x: (N, C, H, W)
|
||||
x = self.patch_embed(x)
|
||||
"""Runs the input through the model layers and returns the transformed output."""
|
||||
x = self.patch_embed(x) # x input is (N, C, H, W)
|
||||
|
||||
x = self.layers[0](x)
|
||||
start_i = 1
|
||||
|
|
@ -579,4 +614,5 @@ class TinyViT(nn.Module):
|
|||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
"""Executes a forward pass on the input tensor through the constructed model layers."""
|
||||
return self.forward_features(x)
|
||||
|
|
|
|||
|
|
@ -21,8 +21,7 @@ class TwoWayTransformer(nn.Module):
|
|||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using
|
||||
queries whose positional embedding is supplied.
|
||||
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
|
|
@ -171,8 +170,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue