Add docformatter to pre-commit (#5279)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-09 02:25:22 +02:00 committed by GitHub
parent c7aa83da31
commit 7517667a33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
90 changed files with 1396 additions and 497 deletions

View file

@ -32,9 +32,10 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Computes the stability score for a batch of masks. The stability
score is the IoU between the binary masks obtained by thresholding
the predicted mask logits at high and low values.
Computes the stability score for a batch of masks.
The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
and low values.
"""
# One mask is always contained inside the other.
# Save memory by preventing unnecessary cast to torch.int64
@ -60,7 +61,11 @@ def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer:
def generate_crop_boxes(im_size: Tuple[int, ...], n_layers: int,
overlap_ratio: float) -> Tuple[List[List[int]], List[int]]:
"""Generates a list of crop boxes of different sizes. Each layer has (2**i)**2 boxes for the ith layer."""
"""
Generates a list of crop boxes of different sizes.
Each layer has (2**i)**2 boxes for the ith layer.
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
@ -145,8 +150,9 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
Calculates boxes in XYXY format around masks.
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:

View file

@ -1,7 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM model interface
"""
"""SAM model interface."""
from pathlib import Path
@ -13,16 +11,16 @@ from .predict import Predictor
class SAM(Model):
"""
SAM model interface.
"""
"""SAM model interface."""
def __init__(self, model='sam_b.pt') -> None:
"""Initializes the SAM model instance with the specified pre-trained model file."""
if model and Path(model).suffix not in ('.pt', '.pth'):
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
super().__init__(model=model, task='segment')
def _load(self, weights: str, task=None):
"""Loads the provided weights into the SAM model."""
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
@ -48,4 +46,5 @@ class SAM(Model):
@property
def task_map(self):
"""Returns a dictionary mapping the 'segment' task to its corresponding 'Predictor'."""
return {'segment': {'predictor': Predictor}}

View file

@ -98,7 +98,11 @@ class MaskDecoder(nn.Module):
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
"""
Predicts masks.
See 'forward' for more details.
"""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)

View file

@ -100,6 +100,9 @@ class ImageEncoderViT(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
and neck.
"""
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
@ -157,8 +160,8 @@ class PromptEncoder(nn.Module):
def get_dense_pe(self) -> torch.Tensor:
"""
Returns the positional encoding used to encode point prompts,
applied to a dense set of points the shape of the image encoding.
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
image encoding.
Returns:
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
@ -204,9 +207,7 @@ class PromptEncoder(nn.Module):
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],
) -> int:
"""
Gets the batch size of the output given the batch size of the input prompts.
"""
"""Gets the batch size of the output given the batch size of the input prompts."""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
@ -217,6 +218,7 @@ class PromptEncoder(nn.Module):
return 1
def _get_device(self) -> torch.device:
"""Returns the device of the first point embedding's weight tensor."""
return self.point_embeddings[0].weight.device
def forward(
@ -259,11 +261,10 @@ class PromptEncoder(nn.Module):
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
"""Positional encoding using random spatial frequencies."""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
"""Initializes a position embedding using random spatial frequencies."""
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
@ -304,7 +305,7 @@ class PositionEmbeddingRandom(nn.Module):
class Block(nn.Module):
"""Transformer blocks with support of window attention and residual propagation blocks"""
"""Transformer blocks with support of window attention and residual propagation blocks."""
def __init__(
self,
@ -351,6 +352,7 @@ class Block(nn.Module):
self.window_size = window_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
shortcut = x
x = self.norm1(x)
# Window partition
@ -404,6 +406,7 @@ class Attention(nn.Module):
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@ -448,6 +451,7 @@ def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[in
hw: Tuple[int, int]) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
Args:
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
window_size (int): window size.
@ -540,9 +544,7 @@ def add_decomposed_rel_pos(
class PatchEmbed(nn.Module):
"""
Image to Patch Embedding.
"""
"""Image to Patch Embedding."""
def __init__(
self,
@ -565,4 +567,5 @@ class PatchEmbed(nn.Module):
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C

View file

@ -23,6 +23,9 @@ from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
drop path.
"""
super().__init__()
self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
@ -34,6 +37,9 @@ class Conv2d_BN(torch.nn.Sequential):
class PatchEmbed(nn.Module):
def __init__(self, in_chans, embed_dim, resolution, activation):
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
function.
"""
super().__init__()
img_size: Tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
@ -48,12 +54,16 @@ class PatchEmbed(nn.Module):
)
def forward(self, x):
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
return self.seq(x)
class MBConv(nn.Module):
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
function.
"""
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
@ -73,6 +83,7 @@ class MBConv(nn.Module):
self.drop_path = nn.Identity()
def forward(self, x):
"""Implements the forward pass for the model architecture."""
shortcut = x
x = self.conv1(x)
x = self.act1(x)
@ -87,6 +98,9 @@ class MBConv(nn.Module):
class PatchMerging(nn.Module):
def __init__(self, input_resolution, dim, out_dim, activation):
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
optional parameters.
"""
super().__init__()
self.input_resolution = input_resolution
@ -99,6 +113,7 @@ class PatchMerging(nn.Module):
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x):
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
@ -149,6 +164,7 @@ class ConvLayer(nn.Module):
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
"""Processes the input through a series of convolutional layers and returns the activated output."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
@ -157,6 +173,7 @@ class ConvLayer(nn.Module):
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -167,6 +184,7 @@ class Mlp(nn.Module):
self.drop = nn.Dropout(drop)
def forward(self, x):
"""Applies operations on input x and returns modified x, runs downsample if not None."""
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
@ -216,6 +234,7 @@ class Attention(torch.nn.Module):
@torch.no_grad()
def train(self, mode=True):
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
super().train(mode)
if mode and hasattr(self, 'ab'):
del self.ab
@ -298,6 +317,9 @@ class TinyViTBlock(nn.Module):
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x):
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
convolution.
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
@ -337,6 +359,9 @@ class TinyViTBlock(nn.Module):
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
attentions heads, window size, and MLP ratio.
"""
return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
f'window_size={self.window_size}, mlp_ratio={self.mlp_ratio}'
@ -402,23 +427,28 @@ class BasicLayer(nn.Module):
input_resolution, dim=dim, out_dim=out_dim, activation=activation)
def forward(self, x):
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
for blk in self.blocks:
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
"""Returns a string representation of the extra_repr function with the layer's parameters."""
return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
class LayerNorm2d(nn.Module):
"""A PyTorch implementation of Layer Normalization in 2D."""
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform a forward pass, normalizing the input tensor."""
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
@ -518,6 +548,7 @@ class TinyViT(nn.Module):
)
def set_layer_lr_decay(self, layer_lr_decay):
"""Sets the learning rate decay for each layer in the TinyViT model."""
decay_rate = layer_lr_decay
# layers -> blocks (depth)
@ -525,6 +556,7 @@ class TinyViT(nn.Module):
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
def _set_lr_scale(m, scale):
"""Sets the learning rate scale for each layer in the model based on the layer's depth."""
for p in m.parameters():
p.lr_scale = scale
@ -544,12 +576,14 @@ class TinyViT(nn.Module):
p.param_name = k
def _check_lr_scale(m):
"""Checks if the learning rate scale attribute is present in module's parameters."""
for p in m.parameters():
assert hasattr(p, 'lr_scale'), p.param_name
self.apply(_check_lr_scale)
def _init_weights(self, m):
"""Initializes weights for linear layers and layer normalization in the given module."""
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
@ -561,11 +595,12 @@ class TinyViT(nn.Module):
@torch.jit.ignore
def no_weight_decay_keywords(self):
"""Returns a dictionary of parameter names where weight decay should not be applied."""
return {'attention_biases'}
def forward_features(self, x):
# x: (N, C, H, W)
x = self.patch_embed(x)
"""Runs the input through the model layers and returns the transformed output."""
x = self.patch_embed(x) # x input is (N, C, H, W)
x = self.layers[0](x)
start_i = 1
@ -579,4 +614,5 @@ class TinyViT(nn.Module):
return self.neck(x)
def forward(self, x):
"""Executes a forward pass on the input tensor through the constructed model layers."""
return self.forward_features(x)

View file

@ -21,8 +21,7 @@ class TwoWayTransformer(nn.Module):
attention_downsample_rate: int = 2,
) -> None:
"""
A transformer decoder that attends to an input image using
queries whose positional embedding is supplied.
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
Args:
depth (int): number of layers in the transformer
@ -171,8 +170,7 @@ class TwoWayAttentionBlock(nn.Module):
class Attention(nn.Module):
"""
An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values.
"""

View file

@ -19,6 +19,7 @@ from .build import build_sam
class Predictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initializes the Predictor class with default or provided configuration, overrides, and callbacks."""
if overrides is None:
overrides = {}
overrides.update(dict(task='segment', mode='predict', imgsz=1024))
@ -34,7 +35,8 @@ class Predictor(BasePredictor):
self.segment_all = False
def preprocess(self, im):
"""Prepares input image before inference.
"""
Prepares input image before inference.
Args:
im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
@ -189,7 +191,8 @@ class Predictor(BasePredictor):
stability_score_thresh=0.95,
stability_score_offset=0.95,
crop_nms_thresh=0.7):
"""Segment the whole image.
"""
Segment the whole image.
Args:
im (torch.Tensor): The preprocessed image, (N, C, H, W).
@ -360,14 +363,15 @@ class Predictor(BasePredictor):
self.prompts = prompts
def reset_image(self):
"""Resets the image and its features to None."""
self.im = None
self.features = None
@staticmethod
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates. Requires open-cv as a dependency.
Removes small disconnected regions and holes in masks, then reruns box NMS to remove any new duplicates.
Requires open-cv as a dependency.
Args:
masks (torch.Tensor): Masks, (N, H, W).