ultralytics 8.2.73 Meta SAM2 Refactor (#14867)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-08-05 08:53:45 +08:00 committed by GitHub
parent bea4c93278
commit 5d9046abda
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 4542 additions and 3624 deletions

View file

@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
"""
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
processing.
A Two-Way Transformer module for simultaneous attention to image and query points.
This class implements a specialized transformer decoder that attends to an input image using queries with
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
cloud processing.
Attributes:
depth (int): The number of layers in the transformer.
embedding_dim (int): The channel dimension for the input embeddings.
num_heads (int): The number of heads for multihead attention.
mlp_dim (int): The internal channel dimension for the MLP block.
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Methods:
forward: Processes image and point embeddings through the transformer.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
def __init__(
@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
Args:
depth (int): number of layers in the transformer
embedding_dim (int): the channel dimension for the input embeddings
num_heads (int): the number of heads for multihead attention. Must
divide embedding_dim
mlp_dim (int): the channel dimension internal to the MLP block
activation (nn.Module): the activation to use in the MLP block
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
mlp_dim (int): Internal channel dimension for the MLP block.
activation (Type[nn.Module]): Activation function to use in the MLP block.
attention_downsample_rate (int): Downsampling rate for attention mechanism.
Attributes:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
super().__init__()
self.depth = depth
@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Processes image and point embeddings through the Two-Way Transformer.
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
Returns:
(torch.Tensor): the processed point_embedding
(torch.Tensor): the processed image_embedding
(Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
class TwoWayAttentionBlock(nn.Module):
"""
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
sparse inputs.
A two-way attention block for simultaneous attention to image and query points.
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
inputs to sparse inputs.
Attributes:
self_attn (Attention): The self-attention layer for the queries.
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
self_attn (Attention): Self-attention layer for queries.
norm1 (nn.LayerNorm): Layer normalization after self-attention.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
mlp (MLPBlock): MLP block that transforms the query embeddings.
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
mlp (MLPBlock): MLP block for transforming query embeddings.
norm3 (nn.LayerNorm): Layer normalization after MLP block.
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
Methods:
forward: Applies self-attention and cross-attention to queries and keys.
Examples:
>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
def __init__(
@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
This block implements a specialized transformer layer with four main components: self-attention on sparse
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
of dense inputs to sparse inputs.
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
activation (nn.Module): the activation of the mlp block
skip_first_layer_pe (bool): skip the PE on the first layer
embedding_dim (int): Channel dimension of the embeddings.
num_heads (int): Number of attention heads in the attention layers.
mlp_dim (int): Hidden dimension of the MLP block.
activation (Type[nn.Module]): Activation function for the MLP block.
attention_downsample_rate (int): Downsampling rate for the attention mechanism.
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
Examples:
>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
self.skip_first_layer_pe = skip_first_layer_pe
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
"""Applies two-way attention to process query and key embeddings in a transformer block."""
# Self attention block
if self.skip_first_layer_pe:
@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
An attention layer with downscaling capability for embedding size after projection.
This class implements a multi-head attention mechanism with the option to downsample the internal
dimension of queries, keys, and values.
Attributes:
embedding_dim (int): Dimensionality of input embeddings.
kv_in_dim (int): Dimensionality of key and value inputs.
internal_dim (int): Internal dimension after downsampling.
num_heads (int): Number of attention heads.
q_proj (nn.Linear): Linear projection for queries.
k_proj (nn.Linear): Linear projection for keys.
v_proj (nn.Linear): Linear projection for values.
out_proj (nn.Linear): Linear projection for output.
Methods:
_separate_heads: Separates input tensor into attention heads.
_recombine_heads: Recombines separated attention heads.
forward: Computes attention output for given query, key, and value tensors.
Examples:
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
def __init__(
@ -214,15 +303,27 @@ class Attention(nn.Module):
kv_in_dim: int = None,
) -> None:
"""
Initializes the Attention model with the given dimensions and settings.
Initializes the Attention module with specified dimensions and settings.
This class implements a multi-head attention mechanism with optional downsampling of the internal
dimension for queries, keys, and values.
Args:
embedding_dim (int): The dimensionality of the input embeddings.
num_heads (int): The number of attention heads.
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
embedding_dim (int): Dimensionality of input embeddings.
num_heads (int): Number of attention heads.
downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
Raises:
AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
Examples:
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
super().__init__()
self.embedding_dim = embedding_dim
@ -238,20 +339,20 @@ class Attention(nn.Module):
@staticmethod
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
"""Separate the input tensor into the specified number of attention heads."""
"""Separates the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
@staticmethod
def _recombine_heads(x: Tensor) -> Tensor:
"""Recombine the separated attention heads into a single tensor."""
"""Recombines separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
"""Compute the attention output given the input query, key, and value tensors."""
"""Applies multi-head attention to query, key, and value tensors with optional downsampling."""
# Input projections
q = self.q_proj(q)