PyCharm Code and Docs Inspect fixes v1 (#18461)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 126867e355
commit 7f1a50e893
26 changed files with 90 additions and 91 deletions
@@ -502,11 +502,11 @@ def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.T
 
 class MultiScaleAttention(nn.Module):
     """
-    Implements multi-scale self-attention with optional query pooling for efficient feature extraction.
+    Implements multiscale self-attention with optional query pooling for efficient feature extraction.
 
-    This class provides a flexible implementation of multi-scale attention, allowing for optional
+    This class provides a flexible implementation of multiscale attention, allowing for optional
     downsampling of query features through pooling. It's designed to enhance the model's ability to
-    capture multi-scale information in visual tasks.
+    capture multiscale information in visual tasks.
 
     Attributes:
         dim (int): Input dimension of the feature map.
@@ -518,7 +518,7 @@ class MultiScaleAttention(nn.Module):
         proj (nn.Linear): Output projection.
 
     Methods:
-        forward: Applies multi-scale attention to the input tensor.
+        forward: Applies multiscale attention to the input tensor.
 
     Examples:
         >>> import torch
@@ -537,7 +537,7 @@ class MultiScaleAttention(nn.Module):
         num_heads: int,
         q_pool: nn.Module = None,
     ):
-        """Initializes multi-scale attention with optional query pooling for efficient feature extraction."""
+        """Initializes multiscale attention with optional query pooling for efficient feature extraction."""
         super().__init__()
 
         self.dim = dim
@@ -552,7 +552,7 @@ class MultiScaleAttention(nn.Module):
         self.proj = nn.Linear(dim_out, dim_out)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Applies multi-scale attention with optional query pooling to extract multi-scale features."""
+        """Applies multiscale attention with optional query pooling to extract multiscale features."""
         B, H, W, _ = x.shape
         # qkv with shape (B, H * W, 3, nHead, C)
         qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
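As an aside for reviewers, the docstring touched above fully describes how MultiScaleAttention is called. A minimal usage sketch follows; the import path and the concrete dim/num_heads values are assumptions for illustration and are not part of this diff:

```python
import torch

# Assumed import path; this commit view does not show the file name.
from ultralytics.models.sam.modules.blocks import MultiScaleAttention

# dim, dim_out, and num_heads are illustrative values, not taken from the diff.
attn = MultiScaleAttention(dim=256, dim_out=256, num_heads=8)
x = torch.randn(1, 32, 32, 256)  # (B, H, W, C) layout, matching forward()
out = attn(x)  # spatial size is unchanged when no q_pool module is given
print(out.shape)  # torch.Size([1, 32, 32, 256])
```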
@@ -582,9 +582,9 @@ class MultiScaleAttention(nn.Module):
 
 class MultiScaleBlock(nn.Module):
     """
-    A multi-scale attention block with window partitioning and query pooling for efficient vision transformers.
+    A multiscale attention block with window partitioning and query pooling for efficient vision transformers.
 
-    This class implements a multi-scale attention mechanism with optional window partitioning and downsampling,
+    This class implements a multiscale attention mechanism with optional window partitioning and downsampling,
     designed for use in vision transformer architectures.
 
     Attributes:
@@ -601,7 +601,7 @@ class MultiScaleBlock(nn.Module):
         proj (nn.Linear | None): Projection layer for dimension mismatch.
 
     Methods:
-        forward: Processes input tensor through the multi-scale block.
+        forward: Processes input tensor through the multiscale block.
 
     Examples:
         >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
@@ -623,7 +623,7 @@ class MultiScaleBlock(nn.Module):
         act_layer: nn.Module = nn.GELU,
         window_size: int = 0,
     ):
-        """Initializes a multi-scale attention block with window partitioning and optional query pooling."""
+        """Initializes a multiscale attention block with window partitioning and optional query pooling."""
         super().__init__()
 
         if isinstance(norm_layer, str):
@@ -660,7 +660,7 @@ class MultiScaleBlock(nn.Module):
         self.proj = nn.Linear(dim, dim_out)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Processes input through multi-scale attention and MLP, with optional windowing and downsampling."""
+        """Processes input through multiscale attention and MLP, with optional windowing and downsampling."""
         shortcut = x  # B, H, W, C
         x = self.norm1(x)
 
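The MultiScaleBlock docstring example shown in the diff (`>>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)`) expands into a runnable sketch along these lines; the import path is an assumption, and the 56x56 input is chosen so the spatial size divides evenly by the 7x7 window:

```python
import torch

# Assumed import path; adjust to the installed ultralytics version.
from ultralytics.models.sam.modules.blocks import MultiScaleBlock

# Constructor arguments taken from the docstring example in this diff.
block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7)
x = torch.randn(1, 56, 56, 256)  # (B, H, W, C), as noted in forward() above
out = block(x)
print(out.shape)  # torch.Size([1, 56, 56, 512]); channels projected from dim to dim_out
```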