ultralytics 8.3.78 new YOLO12 models (#19325)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2025-02-20 20:42:50 +08:00 committed by GitHub
parent f83d679415
commit 216e6fef58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 674 additions and 42 deletions

View file

@ -1154,3 +1154,205 @@ class TorchVision(nn.Module):
else:
y = self.m(x)
return y
class AAttn(nn.Module):
"""
Area-attention module for YOLO models, providing efficient attention mechanisms.
This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,
making it particularly effective for object detection tasks.
Attributes:
area (int): Number of areas the feature map is divided.
num_heads (int): Number of heads into which the attention mechanism is divided.
head_dim (int): Dimension of each attention head.
qkv (Conv): Convolution layer for computing query, key and value tensors.
proj (Conv): Projection convolution layer.
pe (Conv): Position encoding convolution layer.
Methods:
forward: Applies area-attention to input tensor.
Examples:
>>> attn = AAttn(dim=256, num_heads=8, area=4)
>>> x = torch.randn(1, 256, 32, 32)
>>> output = attn(x)
>>> print(output.shape)
torch.Size([1, 256, 32, 32])
"""
def __init__(self, dim, num_heads, area=1):
"""
Initializes an Area-attention module for YOLO models.
Args:
dim (int): Number of hidden channels.
num_heads (int): Number of heads into which the attention mechanism is divided.
area (int): Number of areas the feature map is divided, default is 1.
"""
super().__init__()
self.area = area
self.num_heads = num_heads
self.head_dim = head_dim = dim // num_heads
all_head_dim = head_dim * self.num_heads
self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
self.proj = Conv(all_head_dim, dim, 1, act=False)
self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
def forward(self, x):
"""Processes the input tensor 'x' through the area-attention."""
B, C, H, W = x.shape
N = H * W
qkv = self.qkv(x).flatten(2).transpose(1, 2)
if self.area > 1:
qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
B, N, _ = qkv.shape
q, k, v = (
qkv.view(B, N, self.num_heads, self.head_dim * 3)
.permute(0, 2, 3, 1)
.split([self.head_dim, self.head_dim, self.head_dim], dim=2)
)
attn = (q.transpose(-2, -1) @ k) * (self.head_dim**-0.5)
attn = attn.softmax(dim=-1)
x = v @ attn.transpose(-2, -1)
x = x.permute(0, 3, 1, 2)
v = v.permute(0, 3, 1, 2)
if self.area > 1:
x = x.reshape(B // self.area, N * self.area, C)
v = v.reshape(B // self.area, N * self.area, C)
B, N, _ = x.shape
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
x = x + self.pe(v)
return self.proj(x)
class ABlock(nn.Module):
"""
Area-attention block module for efficient feature extraction in YOLO models.
This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.
It uses a novel area-based attention approach that is more efficient than traditional self-attention while
maintaining effectiveness.
Attributes:
attn (AAttn): Area-attention module for processing spatial features.
mlp (nn.Sequential): Multi-layer perceptron for feature transformation.
Methods:
_init_weights: Initializes module weights using truncated normal distribution.
forward: Applies area-attention and feed-forward processing to input tensor.
Examples:
>>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)
>>> x = torch.randn(1, 256, 32, 32)
>>> output = block(x)
>>> print(output.shape)
torch.Size([1, 256, 32, 32])
"""
def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1):
"""
Initializes an Area-attention block module for efficient feature extraction in YOLO models.
This module implements an area-attention mechanism combined with a feed-forward network for processing feature
maps. It uses a novel area-based attention approach that is more efficient than traditional self-attention
while maintaining effectiveness.
Args:
dim (int): Number of input channels.
num_heads (int): Number of heads into which the attention mechanism is divided.
mlp_ratio (float): Expansion ratio for MLP hidden dimension.
area (int): Number of areas the feature map is divided.
"""
super().__init__()
self.attn = AAttn(dim, num_heads=num_heads, area=area)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
self.apply(self._init_weights)
def _init_weights(self, m):
"""Initialize weights using a truncated normal distribution."""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
"""Forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
x = x + self.attn(x)
return x + self.mlp(x)
class A2C2f(nn.Module):
"""
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature
processing. It supports both area-attention and standard convolution modes.
Attributes:
cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.
cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.
gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.
m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.
Methods:
forward: Processes input through area-attention or standard convolution pathway.
Examples:
>>> m = A2C2f(512, 512, n=1, a2=True, area=1)
>>> x = torch.randn(1, 512, 32, 32)
>>> output = m(x)
>>> print(output.shape)
torch.Size([1, 512, 32, 32])
"""
def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
"""
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
n (int): Number of ABlock or C3k modules to stack.
a2 (bool): Whether to use area attention blocks. If False, uses C3k blocks instead.
area (int): Number of areas the feature map is divided.
residual (bool): Whether to use residual connections with learnable gamma parameter.
mlp_ratio (float): Expansion ratio for MLP hidden dimension.
e (float): Channel expansion ratio for hidden channels.
g (int): Number of groups for grouped convolutions.
shortcut (bool): Whether to use shortcut connections in C3k blocks.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32."
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv((1 + n) * c_, c2, 1)
self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None
self.m = nn.ModuleList(
nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2)))
if a2
else C3k(c_, c_, 2, shortcut, g)
for _ in range(n)
)
def forward(self, x):
"""Forward pass through R-ELAN layer."""
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in self.m)
y = self.cv2(torch.cat(y, 1))
if self.gamma is not None:
return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y
return y