Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -24,7 +24,7 @@ class MaskDecoder(nn.Module):
|
|||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a transformer architecture.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
transformer_dim (int): the channel dimension of the transformer module
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
|
||||
|
|
@ -65,7 +65,7 @@ class MaskDecoder(nn.Module):
|
|||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
|
|
|
|||
|
|
@ -103,13 +103,9 @@ class ImageEncoderViT(nn.Module):
|
|||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
return x
|
||||
return self.neck(x.permute(0, 3, 1, 2))
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
|
|
@ -125,7 +121,7 @@ class PromptEncoder(nn.Module):
|
|||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
|
|
@ -165,8 +161,7 @@ class PromptEncoder(nn.Module):
|
|||
applied to a dense set of points the shape of the image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape
|
||||
1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
|
|
@ -231,21 +226,17 @@ class PromptEncoder(nn.Module):
|
|||
masks: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embeds different types of prompts, returning both sparse and dense
|
||||
embeddings.
|
||||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Arguments:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates
|
||||
and labels to embed.
|
||||
Args:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
|
||||
boxes (torch.Tensor, None): boxes to embed
|
||||
masks (torch.Tensor, None): masks to embed
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
||||
BxNx(embed_dim), where N is determined by the number of input points
|
||||
and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape
|
||||
Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
|
||||
by the number of input points and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
|
|
@ -372,9 +363,7 @@ class Block(nn.Module):
|
|||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
|
@ -427,9 +416,7 @@ class Attention(nn.Module):
|
|||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
|
|
@ -577,7 +564,4 @@ class PatchEmbed(nn.Module):
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
# B C H W -> B H W C
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
return x
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class Sam(nn.Module):
|
|||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the
|
||||
image into image embeddings that allow for efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
|
|
@ -60,14 +60,12 @@ class Sam(nn.Module):
|
|||
multimask_output: bool,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Predicts masks end-to-end from provided images and prompts.
|
||||
If prompts are not known in advance, using SamPredictor is
|
||||
recommended over calling the model directly.
|
||||
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
|
||||
SamPredictor is recommended over calling the model directly.
|
||||
|
||||
Arguments:
|
||||
batched_input (list(dict)): A list over input images, each a
|
||||
dictionary with the following keys. A prompt key can be
|
||||
excluded if it is not present.
|
||||
Args:
|
||||
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
||||
key can be excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format,
|
||||
already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of
|
||||
|
|
@ -81,12 +79,11 @@ class Sam(nn.Module):
|
|||
Already transformed to the input frame of the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
||||
in the form Bx1xHxW.
|
||||
multimask_output (bool): Whether the model should predict multiple
|
||||
disambiguating masks, or return a single mask.
|
||||
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
(list(dict)): A list over input images, where each element is
|
||||
as dictionary with the following keys.
|
||||
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
||||
with shape BxCxHxW, where B is the number of input prompts,
|
||||
C is determined by multimask_output, and (H, W) is the
|
||||
|
|
@ -139,7 +136,7 @@ class Sam(nn.Module):
|
|||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
||||
in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the image input to the
|
||||
|
|
@ -158,8 +155,7 @@ class Sam(nn.Module):
|
|||
align_corners=False,
|
||||
)
|
||||
masks = masks[..., :input_size[0], :input_size[1]]
|
||||
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
|
||||
return masks
|
||||
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
|
||||
|
||||
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize pixel values and pad to a square input."""
|
||||
|
|
|
|||
|
|
@ -35,8 +35,7 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / \
|
||||
(bn.running_var + bn.eps)**0.5
|
||||
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
|
||||
m = torch.nn.Conv2d(w.size(1) * self.c.groups,
|
||||
w.size(0),
|
||||
w.shape[2:],
|
||||
|
|
@ -72,8 +71,7 @@ class PatchEmbed(nn.Module):
|
|||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
self.num_patches = self.patches_resolution[0] * \
|
||||
self.patches_resolution[1]
|
||||
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
n = embed_dim
|
||||
|
|
@ -110,21 +108,14 @@ class MBConv(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
shortcut = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.act2(x)
|
||||
|
||||
x = self.conv3(x)
|
||||
|
||||
x = self.drop_path(x)
|
||||
|
||||
x += shortcut
|
||||
x = self.act3(x)
|
||||
|
||||
return x
|
||||
return self.act3(x)
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
|
|
@ -137,9 +128,7 @@ class PatchMerging(nn.Module):
|
|||
self.out_dim = out_dim
|
||||
self.act = activation()
|
||||
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
||||
stride_c = 2
|
||||
if (out_dim == 320 or out_dim == 448 or out_dim == 576):
|
||||
stride_c = 1
|
||||
stride_c = 1 if out_dim in [320, 448, 576] else 2
|
||||
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
|
|
@ -156,8 +145,7 @@ class PatchMerging(nn.Module):
|
|||
x = self.conv2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv3(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
return x.flatten(2).transpose(1, 2)
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
|
|
@ -174,7 +162,6 @@ class ConvLayer(nn.Module):
|
|||
out_dim=None,
|
||||
conv_expand_ratio=4.,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -192,20 +179,13 @@ class ConvLayer(nn.Module):
|
|||
) for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
|
@ -222,13 +202,11 @@ class Mlp(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
return self.drop(x)
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
|
@ -297,12 +275,12 @@ class Attention(torch.nn.Module):
|
|||
(self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab))
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
r""" TinyViT Block.
|
||||
"""
|
||||
TinyViT Block.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
|
|
@ -312,8 +290,7 @@ class TinyViTBlock(nn.Module):
|
|||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
local_conv_size (int): the kernel size of the convolution between
|
||||
Attention and MLP. Default: 3
|
||||
local_conv_size (int): the kernel size of the convolution between Attention and MLP. Default: 3
|
||||
activation (torch.nn): the activation function. Default: nn.GELU
|
||||
"""
|
||||
|
||||
|
|
@ -391,8 +368,7 @@ class TinyViTBlock(nn.Module):
|
|||
x = self.local_conv(x)
|
||||
x = x.view(B, C, L).transpose(1, 2)
|
||||
|
||||
x = x + self.drop_path(self.mlp(x))
|
||||
return x
|
||||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
|
||||
|
|
@ -400,7 +376,8 @@ class TinyViTBlock(nn.Module):
|
|||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic TinyViT layer for one stage.
|
||||
"""
|
||||
A basic TinyViT layer for one stage.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
|
|
@ -434,7 +411,6 @@ class BasicLayer(nn.Module):
|
|||
activation=nn.GELU,
|
||||
out_dim=None,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -456,20 +432,13 @@ class BasicLayer(nn.Module):
|
|||
) for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
else:
|
||||
self.downsample = None
|
||||
self.downsample = None if downsample is None else downsample(
|
||||
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
|
||||
|
|
@ -487,8 +456,7 @@ class LayerNorm2d(nn.Module):
|
|||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
return x
|
||||
return self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
|
|
@ -548,10 +516,7 @@ class TinyViT(nn.Module):
|
|||
activation=activation,
|
||||
)
|
||||
if i_layer == 0:
|
||||
layer = ConvLayer(
|
||||
conv_expand_ratio=mbconv_expand_ratio,
|
||||
**kwargs,
|
||||
)
|
||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||
else:
|
||||
layer = BasicLayer(num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
|
|
@ -622,7 +587,7 @@ class TinyViT(nn.Module):
|
|||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
|
@ -645,9 +610,7 @@ class TinyViT(nn.Module):
|
|||
B, _, C = x.size()
|
||||
x = x.view(B, 64, 64, C)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = self.neck(x)
|
||||
return x
|
||||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
return x
|
||||
return self.forward_features(x)
|
||||
|
|
|
|||
|
|
@ -61,16 +61,14 @@ class TwoWayTransformer(nn.Module):
|
|||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape
|
||||
B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
||||
have the same shape as image_embedding.
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the processed point_embedding
|
||||
torch.Tensor: the processed image_embedding
|
||||
(torch.Tensor): the processed point_embedding
|
||||
(torch.Tensor): the processed image_embedding
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
bs, c, h, w = image_embedding.shape
|
||||
|
|
@ -112,12 +110,11 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse
|
||||
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
||||
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
|
||||
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
|
|
@ -175,8 +172,8 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer that allows for downscaling the size of the embedding
|
||||
after projection to queries, keys, and values.
|
||||
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -230,6 +227,4 @@ class Attention(nn.Module):
|
|||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
return self.out_proj(out)
|
||||
|
|
|
|||
|
|
@ -145,9 +145,11 @@ class SegmentationValidator(DetectionValidator):
|
|||
def _process_batch(self, detections, labels, pred_masks=None, gt_masks=None, overlap=False, masks=False):
|
||||
"""
|
||||
Return correct prediction matrix
|
||||
Arguments:
|
||||
|
||||
Args:
|
||||
detections (array[N, 6]), x1, y1, x2, y2, conf, class
|
||||
labels (array[M, 5]), class, x1, y1, x2, y2
|
||||
|
||||
Returns:
|
||||
correct (array[N, 10]), for 10 IoU levels
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue