Implement all missing docstrings (#5298)
Co-authored-by: snyk-bot <snyk-bot@snyk.io> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
e7f0658744
commit
7fd5dcbd86
26 changed files with 649 additions and 79 deletions
|
|
@ -10,6 +10,21 @@ from ultralytics.nn.modules import LayerNorm2d
|
|||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
|
||||
masks given image and prompt embeddings.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): The transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for the IoU token.
|
||||
num_mask_tokens (int): Number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for the mask tokens.
|
||||
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
|
||||
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
|
||||
iou_prediction_head (nn.Module): MLP for predicting mask quality.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -136,7 +151,7 @@ class MaskDecoder(nn.Module):
|
|||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Lightly adapted from
|
||||
MLP (Multi-Layer Perceptron) model lightly adapted from
|
||||
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
|
||||
"""
|
||||
|
||||
|
|
@ -148,6 +163,16 @@ class MLP(nn.Module):
|
|||
num_layers: int,
|
||||
sigmoid_output: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MLP (Multi-Layer Perceptron) model.
|
||||
|
||||
Args:
|
||||
input_dim (int): The dimensionality of the input features.
|
||||
hidden_dim (int): The dimensionality of the hidden layers.
|
||||
output_dim (int): The dimensionality of the output layer.
|
||||
num_layers (int): The number of hidden layers.
|
||||
sigmoid_output (bool, optional): Whether to apply a sigmoid activation to the output layer. Defaults to False.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_layers = num_layers
|
||||
h = [hidden_dim] * (num_layers - 1)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,18 @@ from ultralytics.nn.modules import LayerNorm2d, MLPBlock
|
|||
|
||||
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
|
||||
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
|
||||
The encoded patches are then processed through a neck to generate the final encoded representation.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images, assumed to be square.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
|
||||
neck (nn.Sequential): Neck module to further process the output.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -112,6 +124,22 @@ class ImageEncoderViT(nn.Module):
|
|||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
|
||||
produces both sparse and dense embeddings for the input prompts.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
|
||||
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
|
||||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
|
||||
mask_input_size (Tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,20 @@ from .encoders import ImageEncoderViT, PromptEncoder
|
|||
|
||||
|
||||
class Sam(nn.Module):
|
||||
"""
|
||||
Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
|
||||
embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
|
||||
decoder to predict object masks.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): Threshold value for mask prediction.
|
||||
image_format (str): Format of the input image, default is 'RGB'.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
|
||||
pixel_mean (List[float]): Mean pixel values for image normalization.
|
||||
pixel_std (List[float]): Standard deviation values for image normalization.
|
||||
"""
|
||||
mask_threshold: float = 0.0
|
||||
image_format: str = 'RGB'
|
||||
|
||||
|
|
@ -28,18 +42,19 @@ class Sam(nn.Module):
|
|||
pixel_std: List[float] = (58.395, 57.12, 57.375)
|
||||
) -> None:
|
||||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
Initialize the Sam class to predict object masks from an image and input prompts.
|
||||
|
||||
Note:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||
efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
|
||||
(123.675, 116.28, 103.53).
|
||||
pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
|
||||
(58.395, 57.12, 57.375).
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from ultralytics.utils.instance import to_2tuple
|
|||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""A sequential container that performs 2D convolution followed by batch normalization."""
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
|
||||
|
|
@ -35,6 +36,7 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Embeds images into patches and projects them into a specified embedding dimension."""
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
|
||||
|
|
@ -59,6 +61,7 @@ class PatchEmbed(nn.Module):
|
|||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
|
||||
|
|
@ -96,6 +99,7 @@ class MBConv(nn.Module):
|
|||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""Merges neighboring patches in the feature map and projects to a new dimension."""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
|
||||
|
|
@ -130,6 +134,11 @@ class PatchMerging(nn.Module):
|
|||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -143,6 +152,20 @@ class ConvLayer(nn.Module):
|
|||
out_dim=None,
|
||||
conv_expand_ratio=4.,
|
||||
):
|
||||
"""
|
||||
Initializes the ConvLayer with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input image.
|
||||
depth (int): The number of MBConv layers in the block.
|
||||
activation (Callable): Activation function applied after each convolution.
|
||||
drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
|
||||
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -171,6 +194,11 @@ class ConvLayer(nn.Module):
|
|||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) for transformer architectures.
|
||||
|
||||
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
||||
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
|
||||
|
|
@ -194,6 +222,14 @@ class Mlp(nn.Module):
|
|||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
|
||||
resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
|
||||
grid.
|
||||
|
||||
Attributes:
|
||||
ab (Tensor, optional): Cached attention biases for inference, deleted during training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -203,8 +239,21 @@ class Attention(torch.nn.Module):
|
|||
attn_ratio=4,
|
||||
resolution=(14, 14),
|
||||
):
|
||||
"""
|
||||
Initializes the Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads. Default is 8.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
|
||||
Raises:
|
||||
AssertionError: If `resolution` is not a tuple of length 2.
|
||||
"""
|
||||
super().__init__()
|
||||
# (h, w)
|
||||
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim ** -0.5
|
||||
|
|
@ -241,8 +290,9 @@ class Attention(torch.nn.Module):
|
|||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x (B,N,C)
|
||||
B, N, _ = x.shape
|
||||
def forward(self, x): # x
|
||||
"""Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
|
||||
B, N, _ = x.shape # B, N, C
|
||||
|
||||
# Normalization
|
||||
x = self.norm(x)
|
||||
|
|
@ -264,20 +314,7 @@ class Attention(torch.nn.Module):
|
|||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int, int]): Input resolution.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
"""
|
||||
"""TinyViT Block that applies self-attention and a local convolution to the input."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -291,6 +328,24 @@ class TinyViTBlock(nn.Module):
|
|||
local_conv_size=3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViTBlock.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): Window size for attention. Default is 7.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float, optional): Stochastic depth rate. Default is 0.
|
||||
local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `window_size` is not greater than 0.
|
||||
AssertionError: If `dim` is not divisible by `num_heads`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -367,24 +422,7 @@ class TinyViTBlock(nn.Module):
|
|||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
input_resolution (tuple[int]): Input resolution.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
local_conv_size (int): the kernel size of the depthwise convolution between attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
out_dim (int | optional): the output dimension of the layer. Default: None
|
||||
"""
|
||||
"""A basic TinyViT layer for one stage in a TinyViT architecture."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -402,6 +440,27 @@ class BasicLayer(nn.Module):
|
|||
activation=nn.GELU,
|
||||
out_dim=None,
|
||||
):
|
||||
"""
|
||||
Initializes the BasicLayer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
|
||||
local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
out_dim (int | None, optional): The output dimension of the layer. Default is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -456,6 +515,30 @@ class LayerNorm2d(nn.Module):
|
|||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
The TinyViT architecture for vision tasks.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Input image size.
|
||||
in_chans (int): Number of input channels.
|
||||
num_classes (int): Number of classification classes.
|
||||
embed_dims (List[int]): List of embedding dimensions for each layer.
|
||||
depths (List[int]): List of depths for each layer.
|
||||
num_heads (List[int]): List of number of attention heads for each layer.
|
||||
window_sizes (List[int]): List of window sizes for each layer.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_rate (float): Dropout rate for drop layers.
|
||||
drop_path_rate (float): Drop path rate for stochastic depth.
|
||||
use_checkpoint (bool): Use checkpointing for efficient memory usage.
|
||||
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
|
||||
local_conv_size (int): Local convolution kernel size.
|
||||
layer_lr_decay (float): Layer-wise learning rate decay.
|
||||
|
||||
Note:
|
||||
This implementation is generalized to accept a list of depths, attention heads,
|
||||
embedding dimensions and window sizes, which allows you to create a
|
||||
"stack" of TinyViT models of varying configurations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -474,6 +557,25 @@ class TinyViT(nn.Module):
|
|||
local_conv_size=3,
|
||||
layer_lr_decay=1.0,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViT model.
|
||||
|
||||
Args:
|
||||
img_size (int, optional): The input image size. Defaults to 224.
|
||||
in_chans (int, optional): Number of input channels. Defaults to 3.
|
||||
num_classes (int, optional): Number of classification classes. Defaults to 1000.
|
||||
embed_dims (List[int], optional): List of embedding dimensions for each layer. Defaults to [96, 192, 384, 768].
|
||||
depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
|
||||
num_heads (List[int], optional): List of number of attention heads for each layer. Defaults to [3, 6, 12, 24].
|
||||
window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
|
||||
drop_rate (float, optional): Dropout rate. Defaults to 0.
|
||||
drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
|
||||
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
|
||||
local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
|
||||
layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.num_classes = num_classes
|
||||
|
|
|
|||
|
|
@ -10,6 +10,21 @@ from ultralytics.nn.modules import MLPBlock
|
|||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
|
||||
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
|
||||
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
|
||||
processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): The number of layers in the transformer.
|
||||
embedding_dim (int): The channel dimension for the input embeddings.
|
||||
num_heads (int): The number of heads for multihead attention.
|
||||
mlp_dim (int): The internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
|
||||
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
|
||||
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -98,6 +113,23 @@ class TwoWayTransformer(nn.Module):
|
|||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
|
||||
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
|
||||
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
|
||||
sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): The self-attention layer for the queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
|
||||
mlp (MLPBlock): MLP block that transforms the query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -180,6 +212,17 @@ class Attention(nn.Module):
|
|||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Attention model with the given dimensions and settings.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): The dimensionality of the input embeddings.
|
||||
num_heads (int): The number of attention heads.
|
||||
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
|
||||
|
||||
Raises:
|
||||
AssertionError: If 'num_heads' does not evenly divide the internal dimension (embedding_dim / downsample_rate).
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
|
|
@ -191,13 +234,15 @@ class Attention(nn.Module):
|
|||
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
||||
@staticmethod
|
||||
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
def _recombine_heads(self, x: Tensor) -> Tensor:
|
||||
@staticmethod
|
||||
def _recombine_heads(x: Tensor) -> Tensor:
|
||||
"""Recombine the separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue