Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -103,13 +103,9 @@ class ImageEncoderViT(nn.Module):
|
|||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
return self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
|
|
@ -125,7 +121,7 @@ class PromptEncoder(nn.Module):
|
|||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
|
|
@ -165,8 +161,7 @@ class PromptEncoder(nn.Module):
|
|||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
|
|
@ -231,21 +226,17 @@ class PromptEncoder(nn.Module):
|
|||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embeds different types of prompts, returning both sparse and dense
|
||||
embeddings.
|
||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Arguments:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
|
||||
and labels to embed.
|
||||
Args:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
|
||||
boxes (torch.Tensor, None): boxes to embed
|
||||
masks (torch.Tensor, None): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
||||
BxNx(embed_dim), where N is determined by the number of input points
|
||||
and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
|
||||
by the number of input points and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
|
|
@ -372,9 +363,7 @@ class Block(nn.Module):
|
|||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
|
@ -427,9 +416,7 @@ class Attention(nn.Module):
|
|||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
|
|
@ -577,7 +564,4 @@ class PatchEmbed(nn.Module):
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue