ultralytics 8.3.78 new YOLO12 models (#19325)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2025-02-20 20:42:50 +08:00 committed by GitHub
parent f83d679415
commit 216e6fef58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 674 additions and 42 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.77"
__version__ = "8.3.78"
import os

View file

@ -0,0 +1,32 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12-cls image classification model
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/classify
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo12n-cls.yaml' will call yolo12-cls.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 152 layers, 1,820,976 parameters, 1,820,976 gradients, 3.7 GFLOPs
s: [0.50, 0.50, 1024] # summary: 152 layers, 6,206,992 parameters, 6,206,992 gradients, 13.6 GFLOPs
m: [0.50, 1.00, 512] # summary: 172 layers, 12,083,088 parameters, 12,083,088 gradients, 44.2 GFLOPs
l: [1.00, 1.00, 512] # summary: 312 layers, 15,558,640 parameters, 15,558,640 gradients, 56.9 GFLOPs
x: [1.00, 1.50, 512] # summary: 312 layers, 34,172,592 parameters, 34,172,592 gradients, 126.5 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, Classify, [nc]] # Classify

View file

@ -0,0 +1,48 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/obb
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo12n-obb.yaml' will call yolo12-obb.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 287 layers, 2,673,955 parameters, 2,673,939 gradients, 6.9 GFLOPs
s: [0.50, 0.50, 1024] # summary: 287 layers, 9,570,275 parameters, 9,570,259 gradients, 22.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 307 layers, 21,048,003 parameters, 21,047,987 gradients, 71.8 GFLOPs
l: [1.00, 1.00, 512] # summary: 503 layers, 27,299,619 parameters, 27,299,603 gradients, 93.4 GFLOPs
x: [1.00, 1.50, 512] # summary: 503 layers, 61,119,939 parameters, 61,119,923 gradients, 208.6 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, OBB, [nc, 1]] # Detect(P3, P4, P5)

View file

@ -0,0 +1,49 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12-pose keypoints/pose estimation model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/pose
# Parameters
nc: 80 # number of classes
kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
scales: # model compound scaling constants, i.e. 'model=yolo12n-pose.yaml' will call yolo12-pose.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 287 layers, 2,886,715 parameters, 2,886,699 gradients, 7.8 GFLOPs
s: [0.50, 0.50, 1024] # summary: 287 layers, 9,774,155 parameters, 9,774,139 gradients, 23.5 GFLOPs
m: [0.50, 1.00, 512] # summary: 307 layers, 21,057,753 parameters, 21,057,737 gradients, 71.8 GFLOPs
l: [1.00, 1.00, 512] # summary: 503 layers, 27,309,369 parameters, 27,309,353 gradients, 93.5 GFLOPs
x: [1.00, 1.50, 512] # summary: 503 layers, 61,134,489 parameters, 61,134,473 gradients, 208.7 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Pose, [nc, kpt_shape]] # Detect(P3, P4, P5)

View file

@ -0,0 +1,48 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12-seg instance segmentation model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/segment
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo12n-seg.yaml' will call yolo12-seg.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 294 layers, 2,855,056 parameters, 2,855,040 gradients, 10.6 GFLOPs
s: [0.50, 0.50, 1024] # summary: 294 layers, 9,938,592 parameters, 9,938,576 gradients, 35.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 314 layers, 22,505,376 parameters, 22,505,360 gradients, 123.5 GFLOPs
l: [1.00, 1.00, 512] # summary: 510 layers, 28,756,992 parameters, 28,756,976 gradients, 145.1 GFLOPs
x: [1.00, 1.50, 512] # summary: 510 layers, 64,387,264 parameters, 64,387,248 gradients, 324.6 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5)

View file

@ -0,0 +1,48 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# YOLO12 object detection model with P3/8 - P5/32 outputs
# Model docs: https://docs.ultralytics.com/models/yolo12
# Task docs: https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo12n.yaml' will call yolo12.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.50, 0.25, 1024] # summary: 272 layers, 2,602,288 parameters, 2,602,272 gradients, 6.7 GFLOPs
s: [0.50, 0.50, 1024] # summary: 272 layers, 9,284,096 parameters, 9,284,080 gradients, 21.7 GFLOPs
m: [0.50, 1.00, 512] # summary: 292 layers, 20,199,168 parameters, 20,199,152 gradients, 68.1 GFLOPs
l: [1.00, 1.00, 512] # summary: 488 layers, 26,450,784 parameters, 26,450,768 gradients, 89.7 GFLOPs
x: [1.00, 1.50, 512] # summary: 488 layers, 59,210,784 parameters, 59,210,768 gradients, 200.3 GFLOPs
# YOLO12n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 2, C3k2, [256, False, 0.25]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 2, C3k2, [512, False, 0.25]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 4, A2C2f, [512, True, 4]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 4, A2C2f, [1024, True, 1]] # 8
# YOLO12n head
head:
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 2, A2C2f, [512, False, -1]] # 11
- [-1, 1, nn.Upsample, [None, 2, "nearest"]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 2, A2C2f, [256, False, -1]] # 14
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 11], 1, Concat, [1]] # cat head P4
- [-1, 2, A2C2f, [512, False, -1]] # 17
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 8], 1, Concat, [1]] # cat head P5
- [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large)
- [[14, 17, 20], 1, Detect, [nc]] # Detect(P3, P4, P5)

View file

@ -30,6 +30,7 @@ from .block import (
SPP,
SPPELAN,
SPPF,
A2C2f,
AConv,
ADown,
Attention,
@ -160,4 +161,5 @@ __all__ = (
"PSA",
"TorchVision",
"Index",
"A2C2f",
)

View file

@ -1154,3 +1154,205 @@ class TorchVision(nn.Module):
else:
y = self.m(x)
return y
class AAttn(nn.Module):
"""
Area-attention module for YOLO models, providing efficient attention mechanisms.
This module implements an area-based attention mechanism that processes input features in a spatially-aware manner,
making it particularly effective for object detection tasks.
Attributes:
area (int): Number of areas the feature map is divided.
num_heads (int): Number of heads into which the attention mechanism is divided.
head_dim (int): Dimension of each attention head.
qkv (Conv): Convolution layer for computing query, key and value tensors.
proj (Conv): Projection convolution layer.
pe (Conv): Position encoding convolution layer.
Methods:
forward: Applies area-attention to input tensor.
Examples:
>>> attn = AAttn(dim=256, num_heads=8, area=4)
>>> x = torch.randn(1, 256, 32, 32)
>>> output = attn(x)
>>> print(output.shape)
torch.Size([1, 256, 32, 32])
"""
def __init__(self, dim, num_heads, area=1):
"""
Initializes an Area-attention module for YOLO models.
Args:
dim (int): Number of hidden channels.
num_heads (int): Number of heads into which the attention mechanism is divided.
area (int): Number of areas the feature map is divided, default is 1.
"""
super().__init__()
self.area = area
self.num_heads = num_heads
self.head_dim = head_dim = dim // num_heads
all_head_dim = head_dim * self.num_heads
self.qkv = Conv(dim, all_head_dim * 3, 1, act=False)
self.proj = Conv(all_head_dim, dim, 1, act=False)
self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False)
def forward(self, x):
"""Processes the input tensor 'x' through the area-attention."""
B, C, H, W = x.shape
N = H * W
qkv = self.qkv(x).flatten(2).transpose(1, 2)
if self.area > 1:
qkv = qkv.reshape(B * self.area, N // self.area, C * 3)
B, N, _ = qkv.shape
q, k, v = (
qkv.view(B, N, self.num_heads, self.head_dim * 3)
.permute(0, 2, 3, 1)
.split([self.head_dim, self.head_dim, self.head_dim], dim=2)
)
attn = (q.transpose(-2, -1) @ k) * (self.head_dim**-0.5)
attn = attn.softmax(dim=-1)
x = v @ attn.transpose(-2, -1)
x = x.permute(0, 3, 1, 2)
v = v.permute(0, 3, 1, 2)
if self.area > 1:
x = x.reshape(B // self.area, N * self.area, C)
v = v.reshape(B // self.area, N * self.area, C)
B, N, _ = x.shape
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
v = v.reshape(B, H, W, C).permute(0, 3, 1, 2)
x = x + self.pe(v)
return self.proj(x)
class ABlock(nn.Module):
"""
Area-attention block module for efficient feature extraction in YOLO models.
This module implements an area-attention mechanism combined with a feed-forward network for processing feature maps.
It uses a novel area-based attention approach that is more efficient than traditional self-attention while
maintaining effectiveness.
Attributes:
attn (AAttn): Area-attention module for processing spatial features.
mlp (nn.Sequential): Multi-layer perceptron for feature transformation.
Methods:
_init_weights: Initializes module weights using truncated normal distribution.
forward: Applies area-attention and feed-forward processing to input tensor.
Examples:
>>> block = ABlock(dim=256, num_heads=8, mlp_ratio=1.2, area=1)
>>> x = torch.randn(1, 256, 32, 32)
>>> output = block(x)
>>> print(output.shape)
torch.Size([1, 256, 32, 32])
"""
def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1):
"""
Initializes an Area-attention block module for efficient feature extraction in YOLO models.
This module implements an area-attention mechanism combined with a feed-forward network for processing feature
maps. It uses a novel area-based attention approach that is more efficient than traditional self-attention
while maintaining effectiveness.
Args:
dim (int): Number of input channels.
num_heads (int): Number of heads into which the attention mechanism is divided.
mlp_ratio (float): Expansion ratio for MLP hidden dimension.
area (int): Number of areas the feature map is divided.
"""
super().__init__()
self.attn = AAttn(dim, num_heads=num_heads, area=area)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False))
self.apply(self._init_weights)
def _init_weights(self, m):
"""Initialize weights using a truncated normal distribution."""
if isinstance(m, nn.Conv2d):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
"""Forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor."""
x = x + self.attn(x)
return x + self.mlp(x)
class A2C2f(nn.Module):
"""
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
This module extends the C2f architecture by incorporating area-attention and ABlock layers for improved feature
processing. It supports both area-attention and standard convolution modes.
Attributes:
cv1 (Conv): Initial 1x1 convolution layer that reduces input channels to hidden channels.
cv2 (Conv): Final 1x1 convolution layer that processes concatenated features.
gamma (nn.Parameter | None): Learnable parameter for residual scaling when using area attention.
m (nn.ModuleList): List of either ABlock or C3k modules for feature processing.
Methods:
forward: Processes input through area-attention or standard convolution pathway.
Examples:
>>> m = A2C2f(512, 512, n=1, a2=True, area=1)
>>> x = torch.randn(1, 512, 32, 32)
>>> output = m(x)
>>> print(output.shape)
torch.Size([1, 512, 32, 32])
"""
def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True):
"""
Area-Attention C2f module for enhanced feature extraction with area-based attention mechanisms.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
n (int): Number of ABlock or C3k modules to stack.
a2 (bool): Whether to use area attention blocks. If False, uses C3k blocks instead.
area (int): Number of areas the feature map is divided.
residual (bool): Whether to use residual connections with learnable gamma parameter.
mlp_ratio (float): Expansion ratio for MLP hidden dimension.
e (float): Channel expansion ratio for hidden channels.
g (int): Number of groups for grouped convolutions.
shortcut (bool): Whether to use shortcut connections in C3k blocks.
"""
super().__init__()
c_ = int(c2 * e) # hidden channels
assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32."
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv((1 + n) * c_, c2, 1)
self.gamma = nn.Parameter(0.01 * torch.ones(c2), requires_grad=True) if a2 and residual else None
self.m = nn.ModuleList(
nn.Sequential(*(ABlock(c_, c_ // 32, mlp_ratio, area) for _ in range(2)))
if a2
else C3k(c_, c_, 2, shortcut, g)
for _ in range(n)
)
def forward(self, x):
"""Forward pass through R-ELAN layer."""
y = [self.cv1(x)]
y.extend(m(y[-1]) for m in self.m)
y = self.cv2(torch.cat(y, 1))
if self.gamma is not None:
return x + self.gamma.view(-1, len(self.gamma), 1, 1) * y
return y

View file

@ -22,6 +22,7 @@ from ultralytics.nn.modules import (
SPP,
SPPELAN,
SPPF,
A2C2f,
AConv,
ADown,
Bottleneck,
@ -985,6 +986,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
PSA,
SCDown,
C2fCIB,
A2C2f,
}
)
repeat_modules = frozenset( # modules with 'repeat' arguments
@ -1003,6 +1005,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C2fPSA,
C2fCIB,
C2PSA,
A2C2f,
}
)
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
@ -1034,6 +1037,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
legacy = False
if scale in "mlx":
args[3] = True
if m is A2C2f:
legacy = False
if scale in "lx": # for L/X sizes
args.extend((True, 1.2))
elif m is AIFI:
args = [ch[f], *args]
elif m in frozenset({HGStem, HGBlock}):

View file

@ -18,6 +18,7 @@ GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = (
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
+ [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)] # detect models only currently
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "smlx"]