Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -61,16 +61,14 @@ class TwoWayTransformer(nn.Module):
|
|||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape
|
||||
B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
||||
have the same shape as image_embedding.
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the processed point_embedding
|
||||
torch.Tensor: the processed image_embedding
|
||||
(torch.Tensor): the processed point_embedding
|
||||
(torch.Tensor): the processed image_embedding
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
|
|
@ -112,12 +110,11 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse
|
||||
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
||||
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
|
||||
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
|
|
@ -175,8 +172,8 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding
|
||||
after projection to queries, keys, and values.
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -230,6 +227,4 @@ class Attention(nn.Module):
|
|||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
return self.out_proj(out)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue