ultralytics 8.2.73 Meta SAM2 Refactor (#14867)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent bea4c93278
commit 5d9046abda
44 changed files with 4542 additions and 3624 deletions
@@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
 
 class TwoWayTransformer(nn.Module):
     """
-    A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
-    serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
-    is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
-    processing.
+    A Two-Way Transformer module for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer decoder that attends to an input image using queries with
+    supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
+    cloud processing.
 
     Attributes:
-        depth (int): The number of layers in the transformer.
-        embedding_dim (int): The channel dimension for the input embeddings.
-        num_heads (int): The number of heads for multihead attention.
-        mlp_dim (int): The internal channel dimension for the MLP block.
-        layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
-        final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
-        norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
+        depth (int): Number of layers in the transformer.
+        embedding_dim (int): Channel dimension for input embeddings.
+        num_heads (int): Number of heads for multihead attention.
+        mlp_dim (int): Internal channel dimension for the MLP block.
+        layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
+        final_attn_token_to_image (Attention): Final attention layer from queries to image.
+        norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+    Methods:
+        forward: Processes image and point embeddings through the transformer.
+
+    Examples:
+        >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+        >>> image_embedding = torch.randn(1, 256, 32, 32)
+        >>> image_pe = torch.randn(1, 256, 32, 32)
+        >>> point_embedding = torch.randn(1, 100, 256)
+        >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+        >>> print(output_queries.shape, output_image.shape)
     """
 
     def __init__(
@@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
         attention_downsample_rate: int = 2,
     ) -> None:
         """
-        A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
+        Initialize a Two-Way Transformer for simultaneous attention to image and query points.
 
         Args:
-          depth (int): number of layers in the transformer
-          embedding_dim (int): the channel dimension for the input embeddings
-          num_heads (int): the number of heads for multihead attention. Must
-            divide embedding_dim
-          mlp_dim (int): the channel dimension internal to the MLP block
-          activation (nn.Module): the activation to use in the MLP block
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            activation (Type[nn.Module]): Activation function to use in the MLP block.
+            attention_downsample_rate (int): Downsampling rate for attention mechanism.
+
+        Attributes:
+            depth (int): Number of layers in the transformer.
+            embedding_dim (int): Channel dimension for input embeddings.
+            embedding_dim (int): Channel dimension for input embeddings.
+            num_heads (int): Number of heads for multihead attention.
+            mlp_dim (int): Internal channel dimension for the MLP block.
+            layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
+            final_attn_token_to_image (Attention): Final attention layer from queries to image.
+            norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         super().__init__()
         self.depth = depth
@@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
         point_embedding: Tensor,
     ) -> Tuple[Tensor, Tensor]:
         """
+        Processes image and point embeddings through the Two-Way Transformer.
+
         Args:
-          image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
-          image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
-          point_embedding (torch.Tensor): the embedding to add to the query points.
-            Must have shape B x N_points x embedding_dim for any N_points.
+            image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
+            image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
+            point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
 
         Returns:
-          (torch.Tensor): the processed point_embedding
-          (torch.Tensor): the processed image_embedding
+            (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
+
+        Examples:
+            >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
+            >>> image_embedding = torch.randn(1, 256, 32, 32)
+            >>> image_pe = torch.randn(1, 256, 32, 32)
+            >>> point_embedding = torch.randn(1, 100, 256)
+            >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
+            >>> print(output_queries.shape, output_image.shape)
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
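Note on the hunk above: the forward pass first flattens the spatial image embedding into a token sequence before any attention runs. A minimal shape sketch, with sizes assumed from the docstring example (illustrative only, not part of the commit):

import torch

B, C, H, W = 1, 256, 32, 32                # assumed sizes from the docstring example
image_embedding = torch.randn(B, C, H, W)

# BxCxHxW -> BxHWxC, i.e. one token per spatial location (32 * 32 = 1024 tokens)
tokens = image_embedding.flatten(2).permute(0, 2, 1)
print(tokens.shape)  # torch.Size([1, 1024, 256])

Because the image is flattened this way, the doctest's output_image is expected to come back as a (1, 1024, 256) token tensor, while output_queries keeps the (1, 100, 256) point-token shape.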
@@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
 
 class TwoWayAttentionBlock(nn.Module):
     """
-    An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
-    keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
-    of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
-    sparse inputs.
+    A two-way attention block for simultaneous attention to image and query points.
+
+    This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
+    cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
+    inputs to sparse inputs.
 
     Attributes:
-        self_attn (Attention): The self-attention layer for the queries.
-        norm1 (nn.LayerNorm): Layer normalization following the first attention block.
+        self_attn (Attention): Self-attention layer for queries.
+        norm1 (nn.LayerNorm): Layer normalization after self-attention.
         cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
-        norm2 (nn.LayerNorm): Layer normalization following the second attention block.
-        mlp (MLPBlock): MLP block that transforms the query embeddings.
-        norm3 (nn.LayerNorm): Layer normalization following the MLP block.
-        norm4 (nn.LayerNorm): Layer normalization following the third attention block.
+        norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
+        mlp (MLPBlock): MLP block for transforming query embeddings.
+        norm3 (nn.LayerNorm): Layer normalization after MLP block.
+        norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
         cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
-        skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
+        skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+    Methods:
+        forward: Applies self-attention and cross-attention to queries and keys.
+
+    Examples:
+        >>> embedding_dim, num_heads = 256, 8
+        >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+        >>> queries = torch.randn(1, 100, embedding_dim)
+        >>> keys = torch.randn(1, 1000, embedding_dim)
+        >>> query_pe = torch.randn(1, 100, embedding_dim)
+        >>> key_pe = torch.randn(1, 1000, embedding_dim)
+        >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
     """
 
     def __init__(
@@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
         skip_first_layer_pe: bool = False,
     ) -> None:
         """
-        A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
-        inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
-        inputs.
+        Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
+
+        This block implements a specialized transformer layer with four main components: self-attention on sparse
+        inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
+        of dense inputs to sparse inputs.
 
         Args:
-          embedding_dim (int): the channel dimension of the embeddings
-          num_heads (int): the number of heads in the attention layers
-          mlp_dim (int): the hidden dimension of the mlp block
-          activation (nn.Module): the activation of the mlp block
-          skip_first_layer_pe (bool): skip the PE on the first layer
+            embedding_dim (int): Channel dimension of the embeddings.
+            num_heads (int): Number of attention heads in the attention layers.
+            mlp_dim (int): Hidden dimension of the MLP block.
+            activation (Type[nn.Module]): Activation function for the MLP block.
+            attention_downsample_rate (int): Downsampling rate for the attention mechanism.
+            skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
+
+        Examples:
+            >>> embedding_dim, num_heads = 256, 8
+            >>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
+            >>> queries = torch.randn(1, 100, embedding_dim)
+            >>> keys = torch.randn(1, 1000, embedding_dim)
+            >>> query_pe = torch.randn(1, 100, embedding_dim)
+            >>> key_pe = torch.randn(1, 1000, embedding_dim)
+            >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
         """
         super().__init__()
         self.self_attn = Attention(embedding_dim, num_heads)
@@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
         self.skip_first_layer_pe = skip_first_layer_pe
 
     def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
-        """Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
+        """Applies two-way attention to process query and key embeddings in a transformer block."""
 
         # Self attention block
        if self.skip_first_layer_pe:
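To make the four-layer flow described in the TwoWayAttentionBlock docstring concrete, here is an illustrative sketch built on torch.nn.MultiheadAttention. It is not the module from this commit (the real block uses the file's own Attention class with attention_downsample_rate and a skip_first_layer_pe option); it only mirrors the order of operations: self-attention on the sparse queries, cross-attention queries -> keys, an MLP on the queries, then cross-attention keys -> queries, each followed by a LayerNorm.

import torch
import torch.nn as nn

class TwoWayAttentionSketch(nn.Module):
    """Simplified stand-in mirroring the four-step two-way attention flow."""

    def __init__(self, dim: int = 256, num_heads: int = 8, mlp_dim: int = 2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_q_to_k = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_k_to_q = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
        self.norm1, self.norm2, self.norm3, self.norm4 = (nn.LayerNorm(dim) for _ in range(4))

    def forward(self, queries, keys, query_pe, key_pe):
        # (1) self-attention on the sparse queries (positional encodings added to q/k)
        q = queries + query_pe
        queries = self.norm1(queries + self.self_attn(q, q, queries)[0])
        # (2) cross-attention: queries attend to the dense keys
        q, k = queries + query_pe, keys + key_pe
        queries = self.norm2(queries + self.cross_q_to_k(q, k, keys)[0])
        # (3) MLP block on the sparse queries
        queries = self.norm3(queries + self.mlp(queries))
        # (4) cross-attention: dense keys attend back to the queries
        q, k = queries + query_pe, keys + key_pe
        keys = self.norm4(keys + self.cross_k_to_q(k, q, queries)[0])
        return queries, keys

# Shapes assumed from the class Examples above
queries, keys = torch.randn(1, 100, 256), torch.randn(1, 1000, 256)
query_pe, key_pe = torch.randn(1, 100, 256), torch.randn(1, 1000, 256)
out_q, out_k = TwoWayAttentionSketch()(queries, keys, query_pe, key_pe)
print(out_q.shape, out_k.shape)  # torch.Size([1, 100, 256]) torch.Size([1, 1000, 256])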
@@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
 
 
 class Attention(nn.Module):
-    """An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
-    values.
     """
+    An attention layer with downscaling capability for embedding size after projection.
+
+    This class implements a multi-head attention mechanism with the option to downsample the internal
+    dimension of queries, keys, and values.
+
+    Attributes:
+        embedding_dim (int): Dimensionality of input embeddings.
+        kv_in_dim (int): Dimensionality of key and value inputs.
+        internal_dim (int): Internal dimension after downsampling.
+        num_heads (int): Number of attention heads.
+        q_proj (nn.Linear): Linear projection for queries.
+        k_proj (nn.Linear): Linear projection for keys.
+        v_proj (nn.Linear): Linear projection for values.
+        out_proj (nn.Linear): Linear projection for output.
+
+    Methods:
+        _separate_heads: Separates input tensor into attention heads.
+        _recombine_heads: Recombines separated attention heads.
+        forward: Computes attention output for given query, key, and value tensors.
+
+    Examples:
+        >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+        >>> q = torch.randn(1, 100, 256)
+        >>> k = v = torch.randn(1, 50, 256)
+        >>> output = attn(q, k, v)
+        >>> print(output.shape)
+        torch.Size([1, 100, 256])
+    """
 
     def __init__(
@@ -214,15 +303,27 @@ class Attention(nn.Module):
         kv_in_dim: int = None,
     ) -> None:
         """
-        Initializes the Attention model with the given dimensions and settings.
+        Initializes the Attention module with specified dimensions and settings.
 
+        This class implements a multi-head attention mechanism with optional downsampling of the internal
+        dimension for queries, keys, and values.
+
         Args:
-            embedding_dim (int): The dimensionality of the input embeddings.
-            num_heads (int): The number of attention heads.
-            downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
+            embedding_dim (int): Dimensionality of input embeddings.
+            num_heads (int): Number of attention heads.
+            downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
+            kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
 
         Raises:
-            AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
+            AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
+
+        Examples:
+            >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+            >>> q = torch.randn(1, 100, 256)
+            >>> k = v = torch.randn(1, 50, 256)
+            >>> output = attn(q, k, v)
+            >>> print(output.shape)
+            torch.Size([1, 100, 256])
         """
         super().__init__()
         self.embedding_dim = embedding_dim
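A quick arithmetic check of the downsampling rule stated in the docstring above, using the values from its example (this is a sketch for illustration, not code from the commit): the internal projection width is embedding_dim divided by downsample_rate, and num_heads must divide it evenly.

embedding_dim, downsample_rate, num_heads = 256, 2, 8   # example values from the docstring
internal_dim = embedding_dim // downsample_rate         # 128
assert internal_dim % num_heads == 0                    # 8 heads of 16 channels each
print(internal_dim, internal_dim // num_heads)          # 128 16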
@@ -238,20 +339,20 @@ class Attention(nn.Module):
 
     @staticmethod
     def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
-        """Separate the input tensor into the specified number of attention heads."""
+        """Separates the input tensor into the specified number of attention heads."""
         b, n, c = x.shape
         x = x.reshape(b, n, num_heads, c // num_heads)
         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
 
     @staticmethod
    def _recombine_heads(x: Tensor) -> Tensor:
-        """Recombine the separated attention heads into a single tensor."""
+        """Recombines separated attention heads into a single tensor."""
         b, n_heads, n_tokens, c_per_head = x.shape
         x = x.transpose(1, 2)
         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
-        """Compute the attention output given the input query, key, and value tensors."""
+        """Applies multi-head attention to query, key, and value tensors with optional downsampling."""
 
         # Input projections
         q = self.q_proj(q)
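The head split and merge shown in the hunk above are pure reshapes, so they round-trip exactly. A small sanity check, with shapes assumed from the class Examples (illustrative only):

import torch

x = torch.randn(1, 100, 256)                            # (B, N_tokens, C)
num_heads = 8

heads = x.reshape(1, 100, num_heads, 256 // num_heads).transpose(1, 2)
print(heads.shape)                                      # torch.Size([1, 8, 100, 32])

merged = heads.transpose(1, 2).reshape(1, 100, 256)     # back to (B, N_tokens, C)
print(torch.equal(merged, x))                           # True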