Tests and docstrings improvements (#4475)

Glenn Jocher 2023-08-21 17:02:14 +02:00, committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
22 changed files with 107 additions and 186 deletions


@@ -103,13 +103,9 @@ class ImageEncoderViT(nn.Module):
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             x = x + self.pos_embed
         for blk in self.blocks:
             x = blk(x)
-        x = self.neck(x.permute(0, 3, 1, 2))
-        return x
+        return self.neck(x.permute(0, 3, 1, 2))
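For orientation (not part of the diff): the forward pass keeps tokens channels-last through the transformer blocks and only permutes for the neck. A minimal shape sketch, with sizes that are illustrative assumptions rather than values from the repo:

```python
import torch

# Assumed sizes for illustration: 64x64 patch grid, ViT width 768
B, H, W, C = 1, 64, 64, 768
x = torch.randn(B, H, W, C)        # tokens stay channels-last through the blocks
x_chw = x.permute(0, 3, 1, 2)      # B H W C -> B C H W, the layout the neck consumes
assert x_chw.shape == (B, C, H, W)
```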
class PromptEncoder(nn.Module):
@@ -125,7 +121,7 @@ class PromptEncoder(nn.Module):
         """
         Encodes prompts for input to SAM's mask decoder.
 
-        Arguments:
+        Args:
           embed_dim (int): The prompts' embedding dimension
           image_embedding_size (tuple(int, int)): The spatial size of the
             image embedding, as (H, W).
@@ -165,8 +161,7 @@ class PromptEncoder(nn.Module):
             applied to a dense set of points the shape of the image encoding.
 
         Returns:
-          torch.Tensor: Positional encoding with shape
-            1x(embed_dim)x(embedding_h)x(embedding_w)
+          torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
         """
         return self.pe_layer(self.image_embedding_size).unsqueeze(0)
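The consolidated `Returns:` entry above describes a fixed output shape; a quick check of that contract, with dimensions assumed for illustration:

```python
import torch

embed_dim, embedding_h, embedding_w = 256, 64, 64       # assumed sizes
pe = torch.randn(embed_dim, embedding_h, embedding_w)   # stand-in for self.pe_layer(...)
pe = pe.unsqueeze(0)                                    # prepend the batch axis
assert pe.shape == (1, embed_dim, embedding_h, embedding_w)
```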
@@ -231,21 +226,17 @@ class PromptEncoder(nn.Module):
         masks: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Embeds different types of prompts, returning both sparse and dense
-        embeddings.
+        Embeds different types of prompts, returning both sparse and dense embeddings.
 
-        Arguments:
-          points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
-            and labels to embed.
+        Args:
+          points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
           boxes (torch.Tensor, None): boxes to embed
           masks (torch.Tensor, None): masks to embed
 
         Returns:
-          torch.Tensor: sparse embeddings for the points and boxes, with shape
-            BxNx(embed_dim), where N is determined by the number of input points
-            and boxes.
-          torch.Tensor: dense embeddings for the masks, in the shape
-            Bx(embed_dim)x(embed_H)x(embed_W)
+          torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
+            by the number of input points and boxes.
+          torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
         """
         bs = self._get_batch_size(points, boxes, masks)
         sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
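The empty BxNx(embed_dim) tensor created above is the seed that point and box embeddings are concatenated onto, which is what makes N depend on the inputs. A sketch of how N grows, with counts assumed for illustration:

```python
import torch

bs, embed_dim = 2, 256                     # assumed batch size and width
sparse = torch.empty((bs, 0, embed_dim))   # N starts at 0
point_emb = torch.randn(bs, 3, embed_dim)  # e.g. 2 clicked points + 1 padding token
box_emb = torch.randn(bs, 2, embed_dim)    # a box contributes two corner tokens
sparse = torch.cat([sparse, point_emb, box_emb], dim=1)
assert sparse.shape == (bs, 5, embed_dim)  # N = 3 + 2
```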
@@ -372,9 +363,7 @@ class Block(nn.Module):
             x = window_unpartition(x, self.window_size, pad_hw, (H, W))
 
         x = shortcut + x
-        x = x + self.mlp(self.norm2(x))
-
-        return x
+        return x + self.mlp(self.norm2(x))
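The folded return is the second half of a standard pre-norm residual pair; a self-contained sketch with stand-in modules and assumed shapes:

```python
import torch
import torch.nn as nn

norm2 = nn.LayerNorm(8)      # stand-in for self.norm2
mlp = nn.Linear(8, 8)        # stand-in for self.mlp
x = torch.randn(2, 4, 4, 8)  # B H W C, channels-last as in the block
out = x + mlp(norm2(x))      # identical in effect to the removed two-line form
assert out.shape == x.shape
```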
class Attention(nn.Module):
@@ -427,9 +416,7 @@ class Attention(nn.Module):
         attn = attn.softmax(dim=-1)
         x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
-        x = self.proj(x)
-
-        return x
+        return self.proj(x)
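The line feeding `self.proj` packs the per-head outputs back into a channels-last map; tracing the reshapes with concrete sizes (assumed for illustration):

```python
import torch

B, num_heads, H, W, head_dim = 2, 8, 7, 7, 32
attn_v = torch.randn(B * num_heads, H * W, head_dim)  # result of attn @ v, heads folded into batch
x = (attn_v.view(B, num_heads, H, W, -1)              # split batch and heads apart
           .permute(0, 2, 3, 1, 4)                    # B H W num_heads head_dim
           .reshape(B, H, W, -1))                     # merge heads: C = num_heads * head_dim
assert x.shape == (B, H, W, num_heads * head_dim)
```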
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
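`window_partition`'s body lies outside this diff; for orientation, here is a sketch of the usual pad-then-reshape pattern behind that signature. This is an approximation under assumed shapes, not the committed code:

```python
import torch
import torch.nn.functional as F

def window_partition_sketch(x: torch.Tensor, window_size: int):
    """Split a B H W C map into non-overlapping windows, padding H and W as needed."""
    B, H, W, C = x.shape
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h or pad_w:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad last dims: C by 0, then W, then H
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)

windows, (Hp, Wp) = window_partition_sketch(torch.randn(1, 10, 10, 8), 4)
assert windows.shape == (9, 4, 4, 8) and (Hp, Wp) == (12, 12)
```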
@@ -577,7 +564,4 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.proj(x)
-        # B C H W -> B H W C
-        x = x.permute(0, 2, 3, 1)
-        return x
+        return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
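A quick equivalence check for the folded PatchEmbed return, with an assumed 16x16 patch projection (channel counts are illustrative):

```python
import torch
import torch.nn as nn

proj = nn.Conv2d(3, 96, kernel_size=16, stride=16)  # assumed in_chans/embed_dim/patch size
x = torch.randn(1, 3, 64, 64)
out = proj(x).permute(0, 2, 3, 1)                   # B C H W -> B H W C, as in the diff
assert out.shape == (1, 4, 4, 96)
```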