Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
|
||||
and neck.
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
|
@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts,
|
||||
applied to a dense set of points the shape of the image encoding.
|
||||
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
|
||||
image encoding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
|
|
@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
|
|||
boxes: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""
|
||||
Gets the batch size of the output given the batch size of the input prompts.
|
||||
"""
|
||||
"""Gets the batch size of the output given the batch size of the input prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
|
|
@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
|
|||
return 1
|
||||
|
||||
def _get_device(self) -> torch.device:
|
||||
"""Returns the device of the first point embedding's weight tensor."""
|
||||
return self.point_embeddings[0].weight.device
|
||||
|
||||
def forward(
|
||||
|
|
@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""
|
||||
Positional encoding using random spatial frequencies.
|
||||
"""
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
"""Initializes a position embedding using random spatial frequencies."""
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
|
|
@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
|
|||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -351,6 +352,7 @@ class Block(nn.Module):
|
|||
self.window_size = window_size
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
|
|
@ -404,6 +406,7 @@ class Attention(nn.Module):
|
|||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
|
|
@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
|
|||
hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
|
|
@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Image to Patch Embedding.
|
||||
"""
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue