Tests and docstrings improvements (#4475)

Glenn Jocher 2023-08-21 17:02:14 +02:00 committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
22 changed files with 107 additions and 186 deletions


@@ -61,16 +61,14 @@ class TwoWayTransformer(nn.Module):
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
(torch.Tensor): the processed point_embedding
(torch.Tensor): the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
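For context, a minimal usage sketch of the forward signature documented above. The import path and all shapes are assumptions chosen for illustration, not part of this commit:

```python
import torch
from ultralytics.models.sam.modules.transformer import TwoWayTransformer  # assumed module path

# Hypothetical shapes for illustration only.
B, C, H, W, N_points = 1, 256, 64, 64, 5

transformer = TwoWayTransformer(depth=2, embedding_dim=C, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(B, C, H, W)      # B x embedding_dim x h x w
image_pe = torch.randn(B, C, H, W)             # same shape as image_embedding
point_embedding = torch.randn(B, N_points, C)  # B x N_points x embedding_dim

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape)  # (B, N_points, C) -- processed point_embedding
print(keys.shape)     # (B, H*W, C)      -- processed image_embedding
```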
@@ -112,12 +110,11 @@ class TwoWayAttentionBlock(nn.Module):
skip_first_layer_pe: bool = False,
) -> None:
"""
A transformer block with four layers: (1) self-attention of sparse
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
block on sparse inputs, and (4) cross attention of dense inputs to sparse
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
inputs.
Arguments:
Args:
embedding_dim (int): the channel dimension of the embeddings
num_heads (int): the number of heads in the attention layers
mlp_dim (int): the hidden dimension of the mlp block
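A sketch of how the four-layer block described in this docstring might be exercised on its own; the import path and tensor shapes are assumptions for illustration:

```python
import torch
from ultralytics.models.sam.modules.transformer import TwoWayAttentionBlock  # assumed module path

# Illustrative shapes: a few sparse point tokens against a flattened 64x64 image embedding.
B, N_points, N_image_tokens, C = 1, 5, 4096, 256

block = TwoWayAttentionBlock(embedding_dim=C, num_heads=8, mlp_dim=2048)

queries = torch.randn(B, N_points, C)       # sparse inputs (point embeddings)
keys = torch.randn(B, N_image_tokens, C)    # dense inputs (image embedding tokens)
query_pe = torch.randn(B, N_points, C)      # positional encoding for the sparse inputs
key_pe = torch.randn(B, N_image_tokens, C)  # positional encoding for the dense inputs

# Runs self-attention, sparse-to-dense cross attention, the MLP, and dense-to-sparse
# cross attention, returning the updated sparse and dense embeddings.
queries, keys = block(queries, keys, query_pe, key_pe)
```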
@@ -175,8 +172,8 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding
after projection to queries, keys, and values.
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""
def __init__(
@@ -230,6 +227,4 @@ class Attention(nn.Module):
# Get output
out = attn @ v
out = self._recombine_heads(out)
out = self.out_proj(out)
return out
return self.out_proj(out)
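To illustrate the downscaling behaviour described in the Attention docstring, a small sketch (import path, shapes, and downsample_rate are assumptions for illustration):

```python
import torch
from ultralytics.models.sam.modules.transformer import Attention  # assumed module path

# downsample_rate=2 halves the internal dimension used for the q/k/v projections.
B, N, C = 1, 4096, 256
attn = Attention(embedding_dim=C, num_heads=8, downsample_rate=2)

q = k = v = torch.randn(B, N, C)
out = attn(q, k, v)   # projects to C // downsample_rate internally, then out_proj maps back to C
print(out.shape)      # (B, N, C)
```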