ultralytics 8.3.0 YOLO11 Models Release (#16539)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
efb0c17881
commit
6e43d1e1e5
50 changed files with 1154 additions and 407 deletions
|
|
@ -20,6 +20,7 @@ Example:
|
|||
from .block import (
|
||||
C1,
|
||||
C2,
|
||||
C2PSA,
|
||||
C3,
|
||||
C3TR,
|
||||
CIB,
|
||||
|
|
@ -38,7 +39,9 @@ from .block import (
|
|||
C2f,
|
||||
C2fAttn,
|
||||
C2fCIB,
|
||||
C2fPSA,
|
||||
C3Ghost,
|
||||
C3k2,
|
||||
C3x,
|
||||
CBFuse,
|
||||
CBLinear,
|
||||
|
|
@ -110,6 +113,10 @@ __all__ = (
|
|||
"C2",
|
||||
"C3",
|
||||
"C2f",
|
||||
"C3k2",
|
||||
"SCDown",
|
||||
"C2fPSA",
|
||||
"C2PSA",
|
||||
"C2fAttn",
|
||||
"C3x",
|
||||
"C3TR",
|
||||
|
|
@ -149,5 +156,4 @@ __all__ = (
|
|||
"C2fCIB",
|
||||
"Attention",
|
||||
"PSA",
|
||||
"SCDown",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,9 @@ __all__ = (
|
|||
"SPPELAN",
|
||||
"CBFuse",
|
||||
"CBLinear",
|
||||
"C3k2",
|
||||
"C2fPSA",
|
||||
"C2PSA",
|
||||
"RepVGGDW",
|
||||
"CIB",
|
||||
"C2fCIB",
|
||||
|
|
@ -696,6 +699,49 @@ class CBFuse(nn.Module):
|
|||
return torch.sum(torch.stack(res + xs[-1:]), dim=0)
|
||||
|
||||
|
||||
class C3f(nn.Module):
    """
    CSP-style block that fuses two 1x1 branches with the output of every chained bottleneck.

    Attributes:
        cv1 (Conv): 1x1 convolution mapping c1 to the hidden width; its output feeds the bottleneck chain.
        cv2 (Conv): 1x1 convolution mapping c1 to the hidden width; its output goes straight to the concat.
        cv3 (Conv): 1x1 convolution mapping the (2 + n) * hidden concatenated channels to c2.
        m (nn.ModuleList): The n Bottleneck modules applied one after another.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        """
        Build the three convolutions and the bottleneck chain.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            n (int): Number of Bottleneck modules.
            shortcut (bool): Forwarded to each Bottleneck.
            g (int): Forwarded to each Bottleneck.
            e (float): Ratio used to derive the hidden width as int(c2 * e).
        """
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv((2 + n) * hidden, c2, 1)  # optional act=FReLU(c2)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)]
        self.m = nn.ModuleList(blocks)

    def forward(self, x):
        """Run both 1x1 branches, chain the bottlenecks from the cv1 branch, then fuse everything with cv3."""
        feats = [self.cv2(x), self.cv1(x)]
        for block in self.m:
            # Each bottleneck consumes the most recently produced feature map.
            feats.append(block(feats[-1]))
        return self.cv3(torch.cat(feats, 1))
|
||||
|
||||
|
||||
class C3k2(C2f):
    """C2f variant whose inner modules are either C3k blocks or plain Bottleneck blocks."""

    def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True):
        """
        Initialize through C2f, then replace the inner module list.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            n (int): Number of inner modules.
            c3k (bool): If true, each inner module is a C3k built with 2 inner layers; otherwise a Bottleneck.
            e (float): Expansion ratio passed to C2f.
            g (int): Passed to C2f and to every inner module.
            shortcut (bool): Passed to C2f and to every inner module.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        width = self.c  # hidden width set by the C2f initializer

        def make_block():
            """Create one inner module according to the c3k flag."""
            if c3k:
                return C3k(width, width, 2, shortcut, g)
            return Bottleneck(width, width, shortcut, g)

        self.m = nn.ModuleList([make_block() for _ in range(n)])
|
||||
|
||||
|
||||
class C3k(C3):
    """C3 variant whose bottlenecks use a configurable square kernel size."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
        """
        Initialize through C3, then rebuild the bottleneck stack with (k, k) kernels.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            n (int): Number of Bottleneck modules.
            shortcut (bool): Passed to C3 and to every Bottleneck.
            g (int): Passed to C3 and to every Bottleneck.
            e (float): Ratio used to derive the hidden width as int(c2 * e).
            k (int): Kernel size used for both convolutions of every Bottleneck.
        """
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        layers = [Bottleneck(hidden, hidden, shortcut, g, k=(k, k), e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*layers)
|
||||
|
||||
|
||||
class RepVGGDW(torch.nn.Module):
|
||||
"""RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
|
||||
|
||||
|
|
@ -873,25 +919,69 @@ class Attention(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class PSA(nn.Module):
|
||||
class PSABlock(nn.Module):
    """
    PSABlock class implementing a Position-Sensitive Attention block for neural networks.

    The block applies an attention module followed by a two-layer feed-forward network, each optionally wrapped
    in a shortcut (residual) connection.

    Attributes:
        attn (Attention): Attention module built with the given attn_ratio and num_heads.
        ffn (nn.Sequential): Two 1x1 Conv layers, c -> 2 * c -> c, the second one created with act=False.
        add (bool): Whether shortcut connections are added around attn and ffn.

    Methods:
        forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers.

    Examples:
        Create a PSABlock and perform a forward pass
        >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True)
        >>> input_tensor = torch.randn(1, 128, 32, 32)
        >>> output_tensor = psablock(input_tensor)
    """

    def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None:
        """
        Create the attention module and the feed-forward network.

        Args:
            c (int): Number of channels entering and leaving the block.
            attn_ratio (float): Forwarded to Attention.
            num_heads (int): Forwarded to Attention.
            shortcut (bool): Whether to add the input back after attn and after ffn.
        """
        super().__init__()

        self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads)
        expand = Conv(c, 2 * c, 1)
        project = Conv(2 * c, c, 1, act=False)
        self.ffn = nn.Sequential(expand, project)
        self.add = shortcut

    def forward(self, x):
        """Apply attention and then the feed-forward network, adding shortcuts when self.add is set."""
        attended = self.attn(x)
        if self.add:
            attended = x + attended
        refined = self.ffn(attended)
        return attended + refined if self.add else refined
|
||||
|
||||
|
||||
class PSA(nn.Module):
|
||||
"""
|
||||
PSA class for implementing Position-Sensitive Attention in neural networks.
|
||||
|
||||
This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to
|
||||
input tensors, enhancing feature extraction and processing capabilities.
|
||||
|
||||
Attributes:
|
||||
c (int): Number of hidden channels after applying the initial convolution.
|
||||
cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
|
||||
cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
|
||||
attn (Attention): Attention module for spatial attention.
|
||||
ffn (nn.Sequential): Feed-forward network module.
|
||||
attn (Attention): Attention module for position-sensitive attention.
|
||||
ffn (nn.Sequential): Feed-forward network for further processing.
|
||||
|
||||
Methods:
|
||||
forward: Applies position-sensitive attention and feed-forward network to the input tensor.
|
||||
|
||||
Examples:
|
||||
Create a PSA module and apply it to an input tensor
|
||||
>>> psa = PSA(c1=128, c2=128, e=0.5)
|
||||
>>> input_tensor = torch.randn(1, 128, 64, 64)
|
||||
>>> output_tensor = psa.forward(input_tensor)
|
||||
"""
|
||||
|
||||
def __init__(self, c1, c2, e=0.5):
|
||||
"""Initializes convolution layers, attention module, and feed-forward network with channel reduction."""
|
||||
"""Initializes the PSA module with input/output channels and attention mechanism for feature extraction."""
|
||||
super().__init__()
|
||||
assert c1 == c2
|
||||
self.c = int(c1 * e)
|
||||
|
|
@ -902,46 +992,117 @@ class PSA(nn.Module):
|
|||
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass of the PSA module.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Output tensor.
|
||||
"""
|
||||
"""Executes forward pass in PSA module, applying attention and feed-forward layers to the input tensor."""
|
||||
a, b = self.cv1(x).split((self.c, self.c), dim=1)
|
||||
b = b + self.attn(b)
|
||||
b = b + self.ffn(b)
|
||||
return self.cv2(torch.cat((a, b), 1))
|
||||
|
||||
|
||||
class C2PSA(nn.Module):
    """
    C2PSA module with attention mechanism for enhanced feature extraction and processing.

    The input is projected to 2 * c channels and split into two halves; one half passes through a stack of
    PSABlock modules, then both halves are concatenated and projected back to c1 channels.

    Attributes:
        c (int): Number of hidden channels, int(c1 * e).
        cv1 (Conv): 1x1 convolution mapping c1 input channels to 2 * c channels.
        cv2 (Conv): 1x1 convolution mapping the 2 * c concatenated channels back to c1 channels.
        m (nn.Sequential): Sequential container of n PSABlock modules.

    Methods:
        forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations.

    Notes:
        This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules.

    Examples:
        >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5)
        >>> input_tensor = torch.randn(1, 256, 64, 64)
        >>> output_tensor = c2psa(input_tensor)
    """

    def __init__(self, c1, c2, n=1, e=0.5):
        """
        Initialize the C2PSA module.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels; must equal c1.
            n (int): Number of stacked PSABlock modules.
            e (float): Ratio used to derive the hidden width as int(c1 * e).

        Raises:
            AssertionError: If c1 and c2 differ.
        """
        super().__init__()
        assert c1 == c2, f"C2PSA requires c1 == c2, got c1={c1}, c2={c2}"
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        # self.c // 64 evaluates to 0 when self.c < 64; clamp so every block gets at least one attention head.
        num_heads = max(1, self.c // 64)
        self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=num_heads) for _ in range(n)))

    def forward(self, x):
        """Split the cv1 projection in two, run the second half through the PSABlock stack, and fuse with cv2."""
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = self.m(b)
        return self.cv2(torch.cat((a, b), 1))
|
||||
|
||||
|
||||
class C2fPSA(C2f):
    """
    C2fPSA module with enhanced feature extraction using PSA blocks.

    This class extends the C2f module by replacing its inner module list with PSABlock modules. The convolution
    layers, the hidden width `self.c` and the forward logic are inherited from C2f.

    Attributes:
        m (nn.ModuleList): List of n PSABlock modules operating on self.c channels.

    Examples:
        >>> import torch
        >>> from ultralytics.nn.modules import C2fPSA
        >>> model = C2fPSA(c1=256, c2=256, n=3, e=0.5)
        >>> x = torch.randn(1, 256, 64, 64)
        >>> output = model(x)
        >>> print(output.shape)
    """

    def __init__(self, c1, c2, n=1, e=0.5):
        """
        Initialize the C2fPSA module, a variant of C2f with PSA blocks.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels; must equal c1.
            n (int): Number of PSABlock modules.
            e (float): Expansion ratio passed to C2f.

        Raises:
            AssertionError: If c1 and c2 differ.
        """
        assert c1 == c2, f"C2fPSA requires c1 == c2, got c1={c1}, c2={c2}"
        super().__init__(c1, c2, n=n, e=e)
        # self.c // 64 evaluates to 0 when self.c < 64; clamp so every block gets at least one attention head.
        num_heads = max(1, self.c // 64)
        self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=num_heads) for _ in range(n))
|
||||
|
||||
|
||||
class SCDown(nn.Module):
    """
    SCDown module for downsampling with separable convolutions.

    The module chains a 1x1 convolution that changes the channel count with a grouped convolution (one group per
    channel) that applies the requested kernel size and stride.

    Attributes:
        cv1 (Conv): 1x1, stride-1 convolution mapping c1 channels to c2 channels.
        cv2 (Conv): Convolution with kernel k, stride s and groups equal to c2, created with act=False.

    Methods:
        forward: Applies the SCDown module to the input tensor.

    Examples:
        >>> import torch
        >>> from ultralytics.nn.modules import SCDown
        >>> model = SCDown(c1=64, c2=128, k=3, s=2)
        >>> x = torch.randn(1, 64, 128, 128)
        >>> y = model(x)
        >>> print(y.shape)
        torch.Size([1, 128, 64, 64])
    """

    def __init__(self, c1, c2, k, s):
        """
        Initialize the SCDown module with specified input/output channels, kernel size, and stride.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size for the second convolution.
            s (int): Stride for the second convolution.
        """
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)

    def forward(self, x):
        """
        Apply the channel projection and then the strided grouped convolution.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after applying the SCDown module.
        """
        projected = self.cv1(x)
        return self.cv2(projected)
|
||||
|
|
|
|||
|
|
@ -209,7 +209,8 @@ class RepConv(nn.Module):
|
|||
kernelid, biasid = self._fuse_bn_tensor(self.bn)
|
||||
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
|
||||
|
||||
def _pad_1x1_to_3x3_tensor(self, kernel1x1):
|
||||
@staticmethod
|
||||
def _pad_1x1_to_3x3_tensor(kernel1x1):
|
||||
"""Pads a 1x1 tensor to a 3x3 tensor."""
|
||||
if kernel1x1 is None:
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from torch.nn.init import constant_, xavier_uniform_
|
|||
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
|
||||
|
||||
from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto
|
||||
from .conv import Conv
|
||||
from .conv import Conv, DWConv
|
||||
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
|
||||
from .utils import bias_init_with_prob, linear_init
|
||||
|
||||
|
|
@ -41,7 +41,14 @@ class Detect(nn.Module):
|
|||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||
)
|
||||
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
||||
self.cv3 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, self.nc, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
if self.end2end:
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from ultralytics.nn.modules import (
|
|||
AIFI,
|
||||
C1,
|
||||
C2,
|
||||
C2PSA,
|
||||
C3,
|
||||
C3TR,
|
||||
ELAN1,
|
||||
|
|
@ -28,7 +29,9 @@ from ultralytics.nn.modules import (
|
|||
C2f,
|
||||
C2fAttn,
|
||||
C2fCIB,
|
||||
C2fPSA,
|
||||
C3Ghost,
|
||||
C3k2,
|
||||
C3x,
|
||||
CBFuse,
|
||||
CBLinear,
|
||||
|
|
@ -968,12 +971,15 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
GhostBottleneck,
|
||||
SPP,
|
||||
SPPF,
|
||||
C2fPSA,
|
||||
C2PSA,
|
||||
DWConv,
|
||||
Focus,
|
||||
BottleneckCSP,
|
||||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C3k2,
|
||||
RepNCSPELAN4,
|
||||
ELAN1,
|
||||
ADown,
|
||||
|
|
@ -1001,9 +1007,26 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
) # num heads
|
||||
|
||||
args = [c1, c2, *args[1:]]
|
||||
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB}:
|
||||
if m in {
|
||||
BottleneckCSP,
|
||||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C3k2,
|
||||
C2fAttn,
|
||||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
C3x,
|
||||
RepC3,
|
||||
C2fPSA,
|
||||
C2fCIB,
|
||||
C2PSA,
|
||||
}:
|
||||
args.insert(2, n) # number of repeats
|
||||
n = 1
|
||||
if m is C3k2 and scale in "mlx": # for M/L/X sizes
|
||||
args[3] = True
|
||||
elif m is AIFI:
|
||||
args = [ch[f], *args]
|
||||
elif m in {HGStem, HGBlock}:
|
||||
|
|
@ -1080,7 +1103,7 @@ def guess_model_scale(model_path):
|
|||
with contextlib.suppress(AttributeError):
|
||||
import re
|
||||
|
||||
return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
return ""
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue