ultralytics 8.2.73 Meta SAM2 Refactor (#14867)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent bea4c93278
commit 5d9046abda
44 changed files with 4542 additions and 3624 deletions
ultralytics/models/sam/modules/tiny_encoder.py
@@ -17,16 +17,40 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
 
+from ultralytics.nn.modules import LayerNorm2d
 from ultralytics.utils.instance import to_2tuple
 
 
 class Conv2d_BN(torch.nn.Sequential):
-    """A sequential container that performs 2D convolution followed by batch normalization."""
+    """
+    A sequential container that performs 2D convolution followed by batch normalization.
+
+    Attributes:
+        c (torch.nn.Conv2d): 2D convolution layer.
+        bn (torch.nn.BatchNorm2d): Batch normalization layer.
+
+    Methods:
+        __init__: Initializes the Conv2d_BN with specified parameters.
+
+    Args:
+        a (int): Number of input channels.
+        b (int): Number of output channels.
+        ks (int): Kernel size for the convolution. Defaults to 1.
+        stride (int): Stride for the convolution. Defaults to 1.
+        pad (int): Padding for the convolution. Defaults to 0.
+        dilation (int): Dilation factor for the convolution. Defaults to 1.
+        groups (int): Number of groups for the convolution. Defaults to 1.
+        bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
+
+    Examples:
+        >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
+        >>> input_tensor = torch.randn(1, 3, 224, 224)
+        >>> output = conv_bn(input_tensor)
+        >>> print(output.shape)
+    """
 
     def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
-        """Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
-        drop path.
-        """
+        """Initializes a sequential container with 2D convolution followed by batch normalization."""
        super().__init__()
         self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
         bn = torch.nn.BatchNorm2d(b)
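The convolution above is created with `bias=False` because the batch norm that follows can absorb any bias; at inference the pair can be folded into a single conv. A minimal sketch of that folding (standard BN-folding algebra, not code from this diff; `fold_conv_bn` is a hypothetical helper):

```python
import torch


@torch.no_grad()
def fold_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    """Fold an eval-mode BatchNorm2d that follows a bias-free Conv2d into one biased Conv2d."""
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()  # per-output-channel scale
    fused = torch.nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        groups=conv.groups, bias=True,
    )
    fused.weight.copy_(conv.weight * scale[:, None, None, None])  # w' = w * gamma / sqrt(var + eps)
    fused.bias.copy_(bn.bias - bn.running_mean * scale)  # b' = beta - mean * gamma / sqrt(var + eps)
    return fused
```

In eval mode, `fused(x)` matches `bn(conv(x))` up to floating-point error.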
@@ -36,12 +60,29 @@ class Conv2d_BN(torch.nn.Sequential):
 
 
 class PatchEmbed(nn.Module):
-    """Embeds images into patches and projects them into a specified embedding dimension."""
+    """
+    Embeds images into patches and projects them into a specified embedding dimension.
+
+    Attributes:
+        patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
+        num_patches (int): Total number of patches.
+        in_chans (int): Number of input channels.
+        embed_dim (int): Dimension of the embedding.
+        seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
+
+    Methods:
+        forward: Processes the input tensor through the patch embedding sequence.
+
+    Examples:
+        >>> import torch
+        >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> output = patch_embed(x)
+        >>> print(output.shape)
+    """
 
     def __init__(self, in_chans, embed_dim, resolution, activation):
-        """Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
-        function.
-        """
+        """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
         super().__init__()
         img_size: Tuple[int, int] = to_2tuple(resolution)
         self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@@ -56,17 +97,40 @@ class PatchEmbed(nn.Module):
         )
 
     def forward(self, x):
-        """Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
+        """Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
         return self.seq(x)
 
 
 class MBConv(nn.Module):
-    """Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
+    """
+    Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
+
+    Attributes:
+        in_chans (int): Number of input channels.
+        hidden_chans (int): Number of hidden channels.
+        out_chans (int): Number of output channels.
+        conv1 (Conv2d_BN): First convolutional layer.
+        act1 (nn.Module): First activation function.
+        conv2 (Conv2d_BN): Depthwise convolutional layer.
+        act2 (nn.Module): Second activation function.
+        conv3 (Conv2d_BN): Final convolutional layer.
+        act3 (nn.Module): Third activation function.
+        drop_path (nn.Module): Drop path layer (Identity for inference).
+
+    Methods:
+        forward: Performs the forward pass through the MBConv layer.
+
+    Examples:
+        >>> in_chans, out_chans = 32, 64
+        >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
+        >>> x = torch.randn(1, in_chans, 56, 56)
+        >>> output = mbconv(x)
+        >>> print(output.shape)
+        torch.Size([1, 64, 56, 56])
+    """
 
     def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
-        """Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
-        function.
-        """
+        """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
         super().__init__()
         self.in_chans = in_chans
         self.hidden_chans = int(in_chans * expand_ratio)
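The PatchEmbed docstring above pins `patches_resolution` at a quarter of the input resolution. That follows from its two stride-2 convolutions; a quick sketch with plain `nn.Conv2d` standing in for `Conv2d_BN` (an illustrative assumption, the real layers also carry batch norm):

```python
import torch
from torch import nn

embed_dim = 96
seq = nn.Sequential(
    nn.Conv2d(3, embed_dim // 2, 3, stride=2, padding=1),  # 224 -> 112
    nn.GELU(),
    nn.Conv2d(embed_dim // 2, embed_dim, 3, stride=2, padding=1),  # 112 -> 56
)
print(seq(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 96, 56, 56]), i.e. img_size // 4
```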
@@ -86,7 +150,7 @@ class MBConv(nn.Module):
             self.drop_path = nn.Identity()
 
     def forward(self, x):
-        """Implements the forward pass for the model architecture."""
+        """Implements the forward pass of MBConv, applying convolutions and skip connection."""
         shortcut = x
         x = self.conv1(x)
         x = self.act1(x)
@@ -99,12 +163,34 @@ class MBConv(nn.Module):
 
 
 class PatchMerging(nn.Module):
-    """Merges neighboring patches in the feature map and projects to a new dimension."""
+    """
+    Merges neighboring patches in the feature map and projects to a new dimension.
+
+    This class implements a patch merging operation that combines spatial information and adjusts the feature
+    dimension. It uses a series of convolutional layers with batch normalization to achieve this.
+
+    Attributes:
+        input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
+        dim (int): The input dimension of the feature map.
+        out_dim (int): The output dimension after merging and projection.
+        act (nn.Module): The activation function used between convolutions.
+        conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
+        conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
+        conv3 (Conv2d_BN): The third convolutional layer for final projection.
+
+    Methods:
+        forward: Applies the patch merging operation to the input tensor.
+
+    Examples:
+        >>> input_resolution = (56, 56)
+        >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
+        >>> x = torch.randn(4, 64, 56, 56)
+        >>> output = patch_merging(x)
+        >>> print(output.shape)
+    """
 
     def __init__(self, input_resolution, dim, out_dim, activation):
-        """Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
-        optional parameters.
-        """
+        """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
         super().__init__()
 
         self.input_resolution = input_resolution
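PatchMerging's `forward` (whose hunk begins with `if x.ndim == 3:` below) accepts either token sequences or feature maps. A sketch of the token-to-map reshape it performs for 3D inputs, assuming `input_resolution = (H, W)`:

```python
import torch

B, H, W, C = 2, 56, 56, 64
tokens = torch.randn(B, H * W, C)  # (B, L, C) token layout
maps = tokens.view(B, H, W, C).permute(0, 3, 1, 2)  # -> (B, C, H, W) for the convolutions
print(maps.shape)  # torch.Size([2, 64, 56, 56])
```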
@@ -117,7 +203,7 @@ class PatchMerging(nn.Module):
         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
 
     def forward(self, x):
-        """Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
+        """Applies patch merging and dimension projection to the input feature map."""
         if x.ndim == 3:
             H, W = self.input_resolution
             B = len(x)
@@ -137,7 +223,24 @@ class ConvLayer(nn.Module):
     """
     Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
 
-    Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
+    This layer optionally applies downsample operations to the output and supports gradient checkpointing.
+
+    Attributes:
+        dim (int): Dimensionality of the input and output.
+        input_resolution (Tuple[int, int]): Resolution of the input image.
+        depth (int): Number of MBConv layers in the block.
+        use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+        blocks (nn.ModuleList): List of MBConv layers.
+        downsample (Optional[Callable]): Function for downsampling the output.
+
+    Methods:
+        forward: Processes the input through the convolutional layers.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 64, 56, 56)
+        >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+        >>> output = conv_layer(input_tensor)
+        >>> print(output.shape)
     """
 
     def __init__(
@@ -155,16 +258,25 @@ class ConvLayer(nn.Module):
         """
         Initializes the ConvLayer with the given dimensions and settings.
 
+        This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
+        optionally applies downsampling to the output.
+
         Args:
             dim (int): The dimensionality of the input and output.
             input_resolution (Tuple[int, int]): The resolution of the input image.
             depth (int): The number of MBConv layers in the block.
             activation (Callable): Activation function applied after each convolution.
-            drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
+            drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
             downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
             use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
             out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
             conv_expand_ratio (float): Expansion ratio for the MBConv layers.
+
+        Examples:
+            >>> input_tensor = torch.randn(1, 64, 56, 56)
+            >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
+            >>> output = conv_layer(input_tensor)
+            >>> print(output.shape)
         """
         super().__init__()
         self.dim = dim
@@ -194,7 +306,7 @@ class ConvLayer(nn.Module):
         )
 
     def forward(self, x):
-        """Processes the input through a series of convolutional layers and returns the activated output."""
+        """Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
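The `checkpoint.checkpoint(blk, x)` branch in this forward trades compute for memory: activations inside `blk` are discarded after the forward pass and recomputed during backward. The pattern in isolation (a sketch; `use_reentrant=False` is the currently recommended PyTorch setting, not something this diff specifies):

```python
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn

blk = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)
y = checkpoint.checkpoint(blk, x, use_reentrant=False)  # same values as blk(x)
y.sum().backward()  # blk runs again here to rebuild the dropped activations
```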
@@ -202,13 +314,33 @@ class ConvLayer(nn.Module):
 
 class Mlp(nn.Module):
     """
-    Multi-layer Perceptron (MLP) for transformer architectures.
+    Multi-layer Perceptron (MLP) module for transformer architectures.
 
-    This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
+    This module applies layer normalization, two fully-connected layers with an activation function in between,
+    and dropout. It is commonly used in transformer-based architectures.
+
+    Attributes:
+        norm (nn.LayerNorm): Layer normalization applied to the input.
+        fc1 (nn.Linear): First fully-connected layer.
+        fc2 (nn.Linear): Second fully-connected layer.
+        act (nn.Module): Activation function applied after the first fully-connected layer.
+        drop (nn.Dropout): Dropout layer applied after the activation function.
+
+    Methods:
+        forward: Applies the MLP operations on the input tensor.
+
+    Examples:
+        >>> import torch
+        >>> from torch import nn
+        >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
+        >>> x = torch.randn(32, 100, 256)
+        >>> output = mlp(x)
+        >>> print(output.shape)
+        torch.Size([32, 100, 256])
     """
 
     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
-        """Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
+        """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
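Functionally, the Mlp documented above is "pre-norm": normalization comes before the linear layers rather than after. A functional sketch of the same op ordering (illustrative only, with GELU standing in for the configurable `act_layer`):

```python
import torch
import torch.nn.functional as F


def mlp_forward(x, norm, fc1, fc2, p=0.0):
    x = norm(x)  # LayerNorm first (pre-norm)
    x = F.dropout(F.gelu(fc1(x)), p)  # fc1 -> activation -> dropout
    return F.dropout(fc2(x), p)  # fc2 -> dropout
```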
@@ -219,7 +351,7 @@ class Mlp(nn.Module):
         self.drop = nn.Dropout(drop)
 
     def forward(self, x):
-        """Applies operations on input x and returns modified x, runs downsample if not None."""
+        """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
         x = self.norm(x)
         x = self.fc1(x)
         x = self.act(x)
@@ -230,12 +362,37 @@ class Mlp(nn.Module):
 
 class Attention(torch.nn.Module):
     """
-    Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
-    resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
-    grid.
+    Multi-head attention module with spatial awareness and trainable attention biases.
 
+    This module implements a multi-head attention mechanism with support for spatial awareness, applying
+    attention biases based on spatial resolution. It includes trainable attention biases for each unique
+    offset between spatial positions in the resolution grid.
+
     Attributes:
-        ab (Tensor, optional): Cached attention biases for inference, deleted during training.
+        num_heads (int): Number of attention heads.
+        scale (float): Scaling factor for attention scores.
+        key_dim (int): Dimensionality of the keys and queries.
+        nh_kd (int): Product of num_heads and key_dim.
+        d (int): Dimensionality of the value vectors.
+        dh (int): Product of d and num_heads.
+        attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
+        norm (nn.LayerNorm): Layer normalization applied to input.
+        qkv (nn.Linear): Linear layer for computing query, key, and value projections.
+        proj (nn.Linear): Linear layer for final projection.
+        attention_biases (nn.Parameter): Learnable attention biases.
+        attention_bias_idxs (Tensor): Indices for attention biases.
+        ab (Tensor): Cached attention biases for inference, deleted during training.
+
+    Methods:
+        train: Sets the module in training mode and handles the 'ab' attribute.
+        forward: Performs the forward pass of the attention mechanism.
+
+    Examples:
+        >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
+        >>> x = torch.randn(1, 196, 256)
+        >>> output = attn(x)
+        >>> print(output.shape)
+        torch.Size([1, 196, 256])
     """
 
     def __init__(
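The "trainable attention biases for each unique offset" described above work by sharing one learnable scalar per head across all position pairs with the same offset. A sketch of how such an index table can be built (this mirrors the upstream TinyViT construction; treat it as illustrative):

```python
import itertools

import torch

resolution = (14, 14)
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        idxs.append(offsets.setdefault(offset, len(offsets)))  # one bias slot per unique offset
attention_bias_idxs = torch.LongTensor(idxs).view(len(points), len(points))
print(len(offsets), attention_bias_idxs.shape)  # 196 torch.Size([196, 196])
```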
|
@ -247,17 +404,28 @@ class Attention(torch.nn.Module):
|
|||
resolution=(14, 14),
|
||||
):
|
||||
"""
|
||||
Initializes the Attention module.
|
||||
Initializes the Attention module for multi-head attention with spatial awareness.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying
|
||||
attention biases based on spatial resolution. It includes trainable attention biases for each unique
|
||||
offset between spatial positions in the resolution grid.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads. Default is 8.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
num_heads (int): Number of attention heads. Default is 8.
|
||||
attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
|
||||
Raises:
|
||||
AssertionError: If `resolution` is not a tuple of length 2.
|
||||
AssertionError: If 'resolution' is not a tuple of length 2.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
|
||||
>>> x = torch.randn(1, 196, 256)
|
||||
>>> output = attn(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 256])
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
|
@@ -290,7 +458,7 @@ class Attention(torch.nn.Module):
 
     @torch.no_grad()
     def train(self, mode=True):
-        """Sets the module in training mode and handles attribute 'ab' based on the mode."""
+        """Performs multi-head attention with spatial awareness and trainable attention biases."""
         super().train(mode)
         if mode and hasattr(self, "ab"):
             del self.ab
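This `train()` override implements a small cache: in eval mode the gathered biases are precomputed once as `self.ab`, and switching back to training deletes the cache so gradients flow through the live parameter. The pattern reduced to a self-contained sketch (illustrative shapes, not the diff's code):

```python
import torch


class CachedBias(torch.nn.Module):
    """Illustrative reduction of the Attention.train() bias-caching pattern."""

    def __init__(self, num_heads=4, num_offsets=196, n=196):
        super().__init__()
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, num_offsets))
        self.attention_bias_idxs = torch.arange(n).repeat(n).view(n, n) % num_offsets

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, "ab"):
            del self.ab  # a stale cache would bypass parameter updates
        elif not mode:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]  # (heads, n, n)
        return self
```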
@@ -298,7 +466,7 @@ class Attention(torch.nn.Module):
             self.ab = self.attention_biases[:, self.attention_bias_idxs]
 
     def forward(self, x):  # x
-        """Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
+        """Applies multi-head attention with spatial awareness and trainable attention biases."""
         B, N, _ = x.shape  # B, N, C
 
         # Normalization
@@ -322,7 +490,34 @@ class Attention(torch.nn.Module):
 
 
 class TinyViTBlock(nn.Module):
-    """TinyViT Block that applies self-attention and a local convolution to the input."""
+    """
+    TinyViT Block that applies self-attention and a local convolution to the input.
+
+    This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+    local convolutions to process input features efficiently.
+
+    Attributes:
+        dim (int): The dimensionality of the input and output.
+        input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+        num_heads (int): Number of attention heads.
+        window_size (int): Size of the attention window.
+        mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+        drop_path (nn.Module): Stochastic depth layer, identity function during inference.
+        attn (Attention): Self-attention module.
+        mlp (Mlp): Multi-layer perceptron module.
+        local_conv (Conv2d_BN): Depth-wise local convolution layer.
+
+    Methods:
+        forward: Processes the input through the TinyViT block.
+        extra_repr: Returns a string with extra information about the block's parameters.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 196, 192)
+        >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+        >>> output = block(input_tensor)
+        >>> print(output.shape)
+        torch.Size([1, 196, 192])
+    """
 
     def __init__(
         self,
@@ -337,22 +532,32 @@ class TinyViTBlock(nn.Module):
         activation=nn.GELU,
     ):
         """
-        Initializes the TinyViTBlock.
+        Initializes a TinyViT block with self-attention and local convolution.
 
+        This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
+        local convolutions to process input features efficiently.
+
         Args:
-            dim (int): The dimensionality of the input and output.
-            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+            dim (int): Dimensionality of the input and output features.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
             num_heads (int): Number of attention heads.
-            window_size (int, optional): Window size for attention. Default is 7.
-            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
-            drop (float, optional): Dropout rate. Default is 0.
-            drop_path (float, optional): Stochastic depth rate. Default is 0.
-            local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
-            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
+            window_size (int): Size of the attention window. Must be greater than 0.
+            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float): Dropout rate.
+            drop_path (float): Stochastic depth rate.
+            local_conv_size (int): Kernel size of the local convolution.
+            activation (torch.nn.Module): Activation function for MLP.
 
         Raises:
-            AssertionError: If `window_size` is not greater than 0.
-            AssertionError: If `dim` is not divisible by `num_heads`.
+            AssertionError: If window_size is not greater than 0.
+            AssertionError: If dim is not divisible by num_heads.
+
+        Examples:
+            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
+            >>> input_tensor = torch.randn(1, 196, 192)
+            >>> output = block(input_tensor)
+            >>> print(output.shape)
+            torch.Size([1, 196, 192])
         """
         super().__init__()
         self.dim = dim
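TinyViTBlock's forward (shown in the next hunk) attends within `window_size` x `window_size` windows, padding the feature map when its sides aren't multiples of the window. The partition arithmetic, sketched:

```python
import torch
import torch.nn.functional as F

b, h, w, c, ws = 1, 14, 14, 192, 7
x = torch.randn(b, h, w, c)
pad_b, pad_r = (ws - h % ws) % ws, (ws - w % ws) % ws
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))  # pad W and H up to window multiples
nh, nw = (h + pad_b) // ws, (w + pad_r) // ws
windows = x.view(b, nh, ws, nw, ws, c).transpose(2, 3).reshape(b * nh * nw, ws * ws, c)
print(windows.shape)  # torch.Size([4, 49, 192]): 4 windows of 49 tokens each
```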
@@ -380,9 +585,7 @@ class TinyViTBlock(nn.Module):
         self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
 
     def forward(self, x):
-        """Applies attention-based transformation or padding to input 'x' before passing it through a local
-        convolution.
-        """
+        """Applies self-attention, local convolution, and MLP operations to the input tensor."""
         h, w = self.input_resolution
         b, hw, c = x.shape  # batch, height*width, channels
         assert hw == h * w, "input feature has wrong size"
@@ -424,8 +627,19 @@ class TinyViTBlock(nn.Module):
         return x + self.drop_path(self.mlp(x))
 
     def extra_repr(self) -> str:
-        """Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
-        attentions heads, window size, and MLP ratio.
-        """
+        """
+        Returns a string representation of the TinyViTBlock's parameters.
+
+        This method provides a formatted string containing key information about the TinyViTBlock, including its
+        dimension, input resolution, number of attention heads, window size, and MLP ratio.
+
+        Returns:
+            (str): A formatted string containing the block's parameters.
+
+        Examples:
+            >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
+            >>> print(block.extra_repr())
+            dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
+        """
         return (
             f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
@@ -434,7 +648,31 @@ class TinyViTBlock(nn.Module):
 
 
 class BasicLayer(nn.Module):
-    """A basic TinyViT layer for one stage in a TinyViT architecture."""
+    """
+    A basic TinyViT layer for one stage in a TinyViT architecture.
+
+    This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
+    and an optional downsampling operation.
+
+    Attributes:
+        dim (int): The dimensionality of the input and output features.
+        input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
+        depth (int): Number of TinyViT blocks in this layer.
+        use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+        blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
+        downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
+
+    Methods:
+        forward: Processes the input through the layer's blocks and optional downsampling.
+        extra_repr: Returns a string with the layer's parameters for printing.
+
+    Examples:
+        >>> input_tensor = torch.randn(1, 3136, 192)
+        >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+        >>> output = layer(input_tensor)
+        >>> print(output.shape)
+        torch.Size([1, 784, 384])
+    """
 
     def __init__(
         self,
@@ -453,25 +691,34 @@ class BasicLayer(nn.Module):
         out_dim=None,
     ):
         """
-        Initializes the BasicLayer.
+        Initializes a BasicLayer in the TinyViT architecture.
 
+        This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
+        process feature maps at a specific resolution and dimensionality within the TinyViT model.
+
         Args:
-            dim (int): The dimensionality of the input and output.
-            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
-            depth (int): Number of TinyViT blocks.
-            num_heads (int): Number of attention heads.
-            window_size (int): Local window size.
-            mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
-            drop (float, optional): Dropout rate. Default is 0.
-            drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
-            downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
-            use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
-            local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
-            activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
-            out_dim (int | None, optional): The output dimension of the layer. Default is None.
+            dim (int): Dimensionality of the input and output features.
+            input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
+            depth (int): Number of TinyViT blocks in this layer.
+            num_heads (int): Number of attention heads in each TinyViT block.
+            window_size (int): Size of the local window for attention computation.
+            mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
+            drop (float): Dropout rate.
+            drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
+            downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
+            use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
+            local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
+            activation (nn.Module): Activation function used in the MLP.
+            out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
 
         Raises:
-            ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
+            ValueError: If `drop_path` is a list and its length doesn't match `depth`.
+
+        Examples:
+            >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
+            >>> x = torch.randn(1, 56*56, 96)
+            >>> output = layer(x)
+            >>> print(output.shape)
         """
         super().__init__()
         self.dim = dim
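The `drop_path (float | List[float])` argument above is typically fed a per-block slice of a network-wide linear schedule, so deeper blocks get higher stochastic-depth rates. A sketch of that convention (the schedule shape is the common one, not mandated by this diff):

```python
import torch

depths, drop_path_rate = [2, 2, 6, 2], 0.1
dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()  # one rate per block
per_layer = [dpr[sum(depths[:i]) : sum(depths[: i + 1])] for i in range(len(depths))]
print([len(p) for p in per_layer])  # [2, 2, 6, 2], one slice per BasicLayer
```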
@@ -505,58 +752,49 @@ class BasicLayer(nn.Module):
         )
 
     def forward(self, x):
-        """Performs forward propagation on the input tensor and returns a normalized tensor."""
+        """Processes input through TinyViT blocks and optional downsampling."""
         for blk in self.blocks:
             x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
         return x if self.downsample is None else self.downsample(x)
 
     def extra_repr(self) -> str:
-        """Returns a string representation of the extra_repr function with the layer's parameters."""
+        """Returns a string with the layer's parameters for printing."""
         return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
 
 
-class LayerNorm2d(nn.Module):
-    """A PyTorch implementation of Layer Normalization in 2D."""
-
-    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
-        """Initialize LayerNorm2d with the number of channels and an optional epsilon."""
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(num_channels))
-        self.bias = nn.Parameter(torch.zeros(num_channels))
-        self.eps = eps
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Perform a forward pass, normalizing the input tensor."""
-        u = x.mean(1, keepdim=True)
-        s = (x - u).pow(2).mean(1, keepdim=True)
-        x = (x - u) / torch.sqrt(s + self.eps)
-        return self.weight[:, None, None] * x + self.bias[:, None, None]
-
-
 class TinyViT(nn.Module):
     """
-    The TinyViT architecture for vision tasks.
+    TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
 
+    This class implements the TinyViT model, which combines elements of vision transformers and convolutional
+    neural networks for improved efficiency and performance on vision tasks.
+
     Attributes:
         img_size (int): Input image size.
         in_chans (int): Number of input channels.
         num_classes (int): Number of classification classes.
         embed_dims (List[int]): List of embedding dimensions for each layer.
-        depths (List[int]): List of depths for each layer.
         num_heads (List[int]): List of number of attention heads for each layer.
         window_sizes (List[int]): List of window sizes for each layer.
+        depths (List[int]): Number of blocks in each stage.
+        num_layers (int): Total number of layers in the network.
         mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
         drop_rate (float): Dropout rate for drop layers.
         drop_path_rate (float): Drop path rate for stochastic depth.
         use_checkpoint (bool): Use checkpointing for efficient memory usage.
         mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
         local_conv_size (int): Local convolution kernel size.
         layer_lr_decay (float): Layer-wise learning rate decay.
+        patch_embed (PatchEmbed): Module for patch embedding.
+        patches_resolution (Tuple[int, int]): Resolution of embedded patches.
+        layers (nn.ModuleList): List of network layers.
+        norm_head (nn.LayerNorm): Layer normalization for the classifier head.
+        head (nn.Linear): Linear layer for final classification.
+        neck (nn.Sequential): Neck module for feature refinement.
 
-    Note:
-        This implementation is generalized to accept a list of depths, attention heads,
-        embedding dimensions and window sizes, which allows you to create a
-        "stack" of TinyViT models of varying configurations.
+    Methods:
+        set_layer_lr_decay: Sets layer-wise learning rate decay.
+        _init_weights: Initializes weights for linear and normalization layers.
+        no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
+        forward_features: Processes input through the feature extraction layers.
+        forward: Performs a forward pass through the entire network.
+
+    Examples:
+        >>> model = TinyViT(img_size=224, num_classes=1000)
+        >>> x = torch.randn(1, 3, 224, 224)
+        >>> features = model.forward_features(x)
+        >>> print(features.shape)
+        torch.Size([1, 256, 64, 64])
     """
 
     def __init__(
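The `LayerNorm2d` deleted in this hunk (now imported from `ultralytics.nn.modules`, per the first hunk) normalizes each pixel's channel vector. Its arithmetic is equivalent to `F.layer_norm` on a channels-last view; a quick check of that equivalence, under the same eps:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 256, 64, 64)
weight, bias, eps = torch.rand(256), torch.rand(256), 1e-6
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance over channels
manual = weight[:, None, None] * (x - u) / torch.sqrt(s + eps) + bias[:, None, None]
ref = F.layer_norm(x.permute(0, 2, 3, 1), (256,), weight, bias, eps).permute(0, 3, 1, 2)
print(torch.allclose(manual, ref, atol=1e-5))  # True
```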
@@ -579,21 +817,33 @@ class TinyViT(nn.Module):
         """
         Initializes the TinyViT model.
 
+        This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
+        attention and convolution blocks, and a classification head.
+
         Args:
-            img_size (int, optional): The input image size. Defaults to 224.
-            in_chans (int, optional): Number of input channels. Defaults to 3.
-            num_classes (int, optional): Number of classification classes. Defaults to 1000.
-            embed_dims (List[int], optional): List of embedding dimensions per layer. Defaults to [96, 192, 384, 768].
-            depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
-            num_heads (List[int], optional): List of number of attention heads per layer. Defaults to [3, 6, 12, 24].
-            window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
-            mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
-            drop_rate (float, optional): Dropout rate. Defaults to 0.
-            drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
-            use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
-            mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
-            local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
-            layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
+            img_size (int): Size of the input image. Default is 224.
+            in_chans (int): Number of input channels. Default is 3.
+            num_classes (int): Number of classes for classification. Default is 1000.
+            embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
+                Default is (96, 192, 384, 768).
+            depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
+            num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
+                Default is (3, 6, 12, 24).
+            window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
+            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
+            drop_rate (float): Dropout rate. Default is 0.0.
+            drop_path_rate (float): Stochastic depth rate. Default is 0.1.
+            use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
+            mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
+            local_conv_size (int): Kernel size for local convolutions. Default is 3.
+            layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
 
+        Examples:
+            >>> model = TinyViT(img_size=224, num_classes=1000)
+            >>> x = torch.randn(1, 3, 224, 224)
+            >>> output = model(x)
+            >>> print(output.shape)
+            torch.Size([1, 1000])
         """
         super().__init__()
         self.img_size = img_size
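Constructor usage beyond the docstring's defaults, e.g. for the small SAM-style image encoder. The configuration values below are the TinyViT-5M ones from the TinyViT/MobileSAM papers, and the import path assumes this file is `ultralytics/models/sam/modules/tiny_encoder.py`; treat both as assumptions, not something this diff states:

```python
import torch

from ultralytics.models.sam.modules.tiny_encoder import TinyViT

model = TinyViT(
    img_size=1024,
    embed_dims=(64, 128, 160, 320),
    depths=(2, 2, 6, 2),
    num_heads=(2, 4, 5, 10),
    window_sizes=(7, 7, 14, 7),
)
features = model(torch.randn(1, 3, 1024, 1024))  # spatial features from the neck
```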
@@ -671,7 +921,7 @@ class TinyViT(nn.Module):
         )
 
     def set_layer_lr_decay(self, layer_lr_decay):
-        """Sets the learning rate decay for each layer in the TinyViT model."""
+        """Sets layer-wise learning rate decay for the TinyViT model based on depth."""
         decay_rate = layer_lr_decay
 
         # Layers -> blocks (depth)
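The decay itself is geometric in depth: with decay factor d over n layers, layer i gets scale d^(n-i-1), so the final layer trains at full rate. A sketch with illustrative numbers:

```python
decay_rate, num_layers = 0.8, 12
lr_scales = [decay_rate ** (num_layers - i - 1) for i in range(num_layers)]
print(round(lr_scales[0], 4), lr_scales[-1])  # 0.0859 1.0 (earliest vs. last layer)
```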
@@ -706,7 +956,7 @@ class TinyViT(nn.Module):
         self.apply(_check_lr_scale)
 
     def _init_weights(self, m):
-        """Initializes weights for linear layers and layer normalization in the given module."""
+        """Initializes weights for linear and normalization layers in the TinyViT model."""
         if isinstance(m, nn.Linear):
             # NOTE: This initialization is needed only for training.
             # trunc_normal_(m.weight, std=.02)
@@ -718,11 +968,11 @@ class TinyViT(nn.Module):
 
     @torch.jit.ignore
     def no_weight_decay_keywords(self):
-        """Returns a dictionary of parameter names where weight decay should not be applied."""
+        """Returns a set of keywords for parameters that should not use weight decay."""
         return {"attention_biases"}
 
     def forward_features(self, x):
-        """Runs the input through the model layers and returns the transformed output."""
+        """Processes input through feature extraction layers, returning spatial features."""
         x = self.patch_embed(x)  # x input is (N, C, H, W)
 
         x = self.layers[0](x)
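`no_weight_decay_keywords` is the hook an optimizer-building loop consumes: parameters whose names contain a listed keyword go into a zero-weight-decay group. A sketch, assuming `model` is a constructed TinyViT and the hyperparameters are illustrative:

```python
import torch

skip = model.no_weight_decay_keywords()  # {"attention_biases"}
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if any(k in name for k in skip) else decay).append(param)
optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.05}, {"params": no_decay, "weight_decay": 0.0}]
)
```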
@@ -737,5 +987,5 @@ class TinyViT(nn.Module):
         return self.neck(x)
 
     def forward(self, x):
-        """Executes a forward pass on the input tensor through the constructed model layers."""
+        """Performs the forward pass through the TinyViT model, extracting features from the input image."""
         return self.forward_features(x)