Add docformatter to pre-commit (#5279)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Glenn Jocher 2023-10-09 02:25:22 +02:00 committed by GitHub
parent c7aa83da31
commit 7517667a33
90 changed files with 1396 additions and 497 deletions

@@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through patch embedding, applies positional embedding if present, and passes through blocks
+        and neck.
+        """
         x = self.patch_embed(x)
         if self.pos_embed is not None:
             x = x + self.pos_embed
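
Editor's note, not part of the diff: a minimal sketch of the shape flow this forward pass describes, assuming the common SAM ViT defaults of a 1024-pixel image, 16-pixel patches and 768-dim embeddings (values are assumptions, not read from this file).

import torch

x = torch.randn(1, 3, 1024, 1024)        # input image batch
tokens = torch.randn(1, 64, 64, 768)     # stand-in for self.patch_embed(x): (B, H/16, W/16, C)
pos_embed = torch.randn(1, 64, 64, 768)  # stand-in for self.pos_embed
tokens = tokens + pos_embed              # element-wise add keeps (1, 64, 64, 768)
# the transformer blocks preserve this shape before the neck produces the final image embedding
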
@@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
     def get_dense_pe(self) -> torch.Tensor:
         """
-        Returns the positional encoding used to encode point prompts,
-        applied to a dense set of points the shape of the image encoding.
+        Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
+        image encoding.
 
         Returns:
           torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
@@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
     ) -> int:
-        """
-        Gets the batch size of the output given the batch size of the input prompts.
-        """
+        """Gets the batch size of the output given the batch size of the input prompts."""
         if points is not None:
             return points[0].shape[0]
         elif boxes is not None:
@@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
             return 1
 
     def _get_device(self) -> torch.device:
+        """Returns the device of the first point embedding's weight tensor."""
         return self.point_embeddings[0].weight.device
 
     def forward(
@@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
 class PositionEmbeddingRandom(nn.Module):
-    """
-    Positional encoding using random spatial frequencies.
-    """
+    """Positional encoding using random spatial frequencies."""
 
     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        """Initializes a position embedding using random spatial frequencies."""
         super().__init__()
         if scale is None or scale <= 0.0:
             scale = 1.0
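
Editor's note, not part of the diff: a minimal sketch of positional encoding with random spatial frequencies (random Fourier features), the idea the class docstring names. The matrix shape, coordinate normalization and scale below are illustrative assumptions, not the exact implementation.

import math
import torch

num_pos_feats, scale = 64, 1.0
freqs = scale * torch.randn(2, num_pos_feats)     # random frequency matrix for (x, y) coordinates
coords = torch.rand(10, 2) * 2 - 1                # 10 points, normalized to [-1, 1]
proj = 2 * math.pi * coords @ freqs               # (10, num_pos_feats)
pe = torch.cat([proj.sin(), proj.cos()], dim=-1)  # (10, 2 * num_pos_feats) positional features
print(pe.shape)                                   # torch.Size([10, 128])
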
@@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
 class Block(nn.Module):
-    """Transformer blocks with support of window attention and residual propagation blocks"""
+    """Transformer blocks with support of window attention and residual propagation blocks."""
 
     def __init__(
         self,
@@ -351,6 +352,7 @@ class Block(nn.Module):
         self.window_size = window_size
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
         shortcut = x
         x = self.norm1(x)
         # Window partition
@@ -404,6 +406,7 @@ class Attention(nn.Module):
             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
         B, H, W, _ = x.shape
         # qkv with shape (3, B, nHead, H * W, C)
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
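
Editor's note, not part of the diff: a shape walk-through of the qkv reshape noted in the comment above; the toy sizes and the random stand-in for self.qkv(x) are assumptions for illustration only.

import torch

B, H, W, C, num_heads = 2, 4, 4, 8, 2
qkv_out = torch.randn(B, H, W, 3 * C)                                      # stand-in for self.qkv(x)
qkv = qkv_out.reshape(B, H * W, 3, num_heads, -1).permute(2, 0, 3, 1, 4)
print(qkv.shape)  # torch.Size([3, 2, 2, 16, 4]) == (3, B, nHead, H * W, C per head)
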
@@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int],
                        hw: Tuple[int, int]) -> torch.Tensor:
     """
     Window unpartition into original sequences and removing padding.
+
     Args:
         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
         window_size (int): window size.
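
Editor's note, not part of the diff: a rough sketch of the window partition / unpartition shape round trip the docstring describes; the reshapes and toy sizes below are illustrative, not the library functions themselves.

import torch

B, window_size, C = 1, 2, 3
Hp = Wp = 4                                                  # padded height and width
num_windows = (Hp // window_size) * (Wp // window_size)
windows = torch.randn(B * num_windows, window_size, window_size, C)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, C)
print(x.shape)  # torch.Size([1, 4, 4, 3]); cropping to (H, W) then removes the padding
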
@@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
 class PatchEmbed(nn.Module):
-    """
-    Image to Patch Embedding.
-    """
+    """Image to Patch Embedding."""
 
     def __init__(
         self,
@@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Computes patch embedding by applying convolution and transposing resulting tensor."""
         return self.proj(x).permute(0, 2, 3, 1)  # B C H W -> B H W C
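
Editor's note, not part of the diff: a minimal sketch of the B C H W -> B H W C patch embedding described above; the channel count, kernel size and stride are illustrative defaults, not necessarily those used in this file.

import torch
import torch.nn as nn

proj = nn.Conv2d(3, 8, kernel_size=16, stride=16)  # one patch per 16x16 block
x = torch.randn(1, 3, 64, 64)
out = proj(x).permute(0, 2, 3, 1)                  # B C H W -> B H W C
print(out.shape)                                   # torch.Size([1, 4, 4, 8])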