ultralytics 8.2.38 official YOLOv10 support (#13113)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Burhan 2024-06-20 14:31:48 -04:00 committed by GitHub
parent 821e5fa477
commit ffb46fd7fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 785 additions and 32 deletions

View file

@ -22,18 +22,22 @@ from .block import (
C2,
C3,
C3TR,
CIB,
DFL,
ELAN1,
PSA,
SPP,
SPPELAN,
SPPF,
AConv,
ADown,
Attention,
BNContrastiveHead,
Bottleneck,
BottleneckCSP,
C2f,
C2fAttn,
C2fCIB,
C3Ghost,
C3x,
CBFuse,
@ -46,7 +50,9 @@ from .block import (
Proto,
RepC3,
RepNCSPELAN4,
RepVGGDW,
ResNetLayer,
SCDown,
)
from .conv import (
CBAM,
@ -63,7 +69,7 @@ from .conv import (
RepConv,
SpatialAttention,
)
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect
from .transformer import (
AIFI,
MLP,
@ -137,4 +143,10 @@ __all__ = (
"CBLinear",
"AConv",
"ELAN1",
"RepVGGDW",
"CIB",
"C2fCIB",
"Attention",
"PSA",
"SCDown",
)

View file

@ -5,6 +5,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.torch_utils import fuse_conv_and_bn
from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock
@ -39,6 +41,12 @@ __all__ = (
"CBFuse",
"CBLinear",
"Silence",
"RepVGGDW",
"CIB",
"C2fCIB",
"Attention",
"PSA",
"SCDown",
)
@ -699,3 +707,251 @@ class CBFuse(nn.Module):
target_size = xs[-1].shape[2:]
res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
return torch.sum(torch.stack(res + xs[-1:]), dim=0)
class RepVGGDW(torch.nn.Module):
"""RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture."""
def __init__(self, ed) -> None:
super().__init__()
self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
self.dim = ed
self.act = nn.SiLU()
def forward(self, x):
"""
Performs a forward pass of the RepVGGDW block.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the depth wise separable convolution.
"""
return self.act(self.conv(x) + self.conv1(x))
def forward_fuse(self, x):
"""
Performs a forward pass of the RepVGGDW block without fusing the convolutions.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the depth wise separable convolution.
"""
return self.act(self.conv(x))
@torch.no_grad()
def fuse(self):
"""
Fuses the convolutional layers in the RepVGGDW block.
This method fuses the convolutional layers and updates the weights and biases accordingly.
"""
conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)
conv_w = conv.weight
conv_b = conv.bias
conv1_w = conv1.weight
conv1_b = conv1.bias
conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])
final_conv_w = conv_w + conv1_w
final_conv_b = conv_b + conv1_b
conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)
self.conv = conv
del self.conv1
class CIB(nn.Module):
"""
Conditional Identity Block (CIB) module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
"""
def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
"""Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = nn.Sequential(
Conv(c1, c1, 3, g=c1),
Conv(c1, 2 * c_, 1),
Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
Conv(2 * c_, c2, 1),
Conv(c2, c2, 3, g=c2),
)
self.add = shortcut and c1 == c2
def forward(self, x):
"""
Forward pass of the CIB module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
return x + self.cv1(x) if self.add else self.cv1(x)
class C2fCIB(C2f):
"""
C2fCIB class represents a convolutional block with C2f and CIB modules.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
n (int, optional): Number of CIB modules to stack. Defaults to 1.
shortcut (bool, optional): Whether to use shortcut connection. Defaults to False.
lk (bool, optional): Whether to use local key connection. Defaults to False.
g (int, optional): Number of groups for grouped convolution. Defaults to 1.
e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
"""
def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
"""Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion."""
super().__init__(c1, c2, n, shortcut, g, e)
self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))
class Attention(nn.Module):
"""
Attention module that performs self-attention on the input tensor.
Args:
dim (int): The input tensor dimension.
num_heads (int): The number of attention heads.
attn_ratio (float): The ratio of the attention key dimension to the head dimension.
Attributes:
num_heads (int): The number of attention heads.
head_dim (int): The dimension of each attention head.
key_dim (int): The dimension of the attention key.
scale (float): The scaling factor for the attention scores.
qkv (Conv): Convolutional layer for computing the query, key, and value.
proj (Conv): Convolutional layer for projecting the attended values.
pe (Conv): Convolutional layer for positional encoding.
"""
def __init__(self, dim, num_heads=8, attn_ratio=0.5):
"""Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.key_dim = int(self.head_dim * attn_ratio)
self.scale = self.key_dim**-0.5
nh_kd = nh_kd = self.key_dim * num_heads
h = dim + nh_kd * 2
self.qkv = Conv(dim, h, 1, act=False)
self.proj = Conv(dim, dim, 1, act=False)
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
def forward(self, x):
"""
Forward pass of the Attention module.
Args:
x (torch.Tensor): The input tensor.
Returns:
(torch.Tensor): The output tensor after self-attention.
"""
B, C, H, W = x.shape
N = H * W
qkv = self.qkv(x)
q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
[self.key_dim, self.key_dim, self.head_dim], dim=2
)
attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
x = self.proj(x)
return x
class PSA(nn.Module):
"""
Position-wise Spatial Attention module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
e (float): Expansion factor for the intermediate channels. Default is 0.5.
Attributes:
c (int): Number of intermediate channels.
cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c.
cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c.
attn (Attention): Attention module for spatial attention.
ffn (nn.Sequential): Feed-forward network module.
"""
def __init__(self, c1, c2, e=0.5):
"""Initializes convolution layers, attention module, and feed-forward network with channel reduction."""
super().__init__()
assert c1 == c2
self.c = int(c1 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv(2 * self.c, c1, 1)
self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))
def forward(self, x):
"""
Forward pass of the PSA module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor.
"""
a, b = self.cv1(x).split((self.c, self.c), dim=1)
b = b + self.attn(b)
b = b + self.ffn(b)
return self.cv2(torch.cat((a, b), 1))
class SCDown(nn.Module):
def __init__(self, c1, c2, k, s):
"""
Spatial Channel Downsample (SCDown) module.
Args:
c1 (int): Number of input channels.
c2 (int): Number of output channels.
k (int): Kernel size for the convolutional layer.
s (int): Stride for the convolutional layer.
"""
super().__init__()
self.cv1 = Conv(c1, c2, 1, 1)
self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)
def forward(self, x):
"""
Forward pass of the SCDown module.
Args:
x (torch.Tensor): Input tensor.
Returns:
(torch.Tensor): Output tensor after applying the SCDown module.
"""
return self.cv2(self.cv1(x))

View file

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Model head modules."""
import copy
import math
import torch
@ -14,7 +15,7 @@ from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
class Detect(nn.Module):
@ -22,6 +23,8 @@ class Detect(nn.Module):
dynamic = False # force grid reconstruction
export = False # export mode
end2end = False # end2end
max_det = 300 # max_det
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
@ -41,13 +44,48 @@ class Detect(nn.Module):
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
if self.end2end:
self.one2one_cv2 = copy.deepcopy(self.cv2)
self.one2one_cv3 = copy.deepcopy(self.cv3)
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
if self.end2end:
return self.forward_end2end(x)
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return x
y = self._inference(x)
return y if self.export else (y, x)
def forward_end2end(self, x):
"""
Performs forward pass of the v10Detect module.
Args:
x (tensor): Input tensor.
Returns:
(dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
"""
x_detach = [xi.detach() for xi in x]
one2one = [
torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
]
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return {"one2many": x, "one2one": one2one}
y = self._inference(one2one)
y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
return y if self.export else (y, {"one2many": x, "one2one": one2one})
def _inference(self, x):
"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
@ -73,7 +111,7 @@ class Detect(nn.Module):
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
return y
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
@ -83,10 +121,47 @@ class Detect(nn.Module):
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
if self.end2end:
for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes, anchors):
"""Decode bounding boxes."""
return dist2bbox(bboxes, anchors, xywh=True, dim=1)
return dist2bbox(bboxes, anchors, xywh=not self.end2end, dim=1)
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
"""
Post-processes the predictions obtained from a YOLOv10 model.
Args:
preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
max_det (int): The maximum number of detections to keep.
nc (int, optional): The number of classes. Defaults to 80.
Returns:
(torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
including bounding boxes, scores and cls.
"""
assert 4 + nc == preds.shape[-1]
boxes, scores = preds.split([4, nc], dim=-1)
max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
# NOTE: simplify but result slightly lower mAP
# scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
class Segment(Detect):
@ -487,3 +562,39 @@ class RTDETRDecoder(nn.Module):
xavier_uniform_(self.query_pos_head.layers[1].weight)
for layer in self.input_proj:
xavier_uniform_(layer[0].weight)
class v10Detect(Detect):
"""
v10 Detection head from https://arxiv.org/pdf/2405.14458
Args:
nc (int): Number of classes.
ch (tuple): Tuple of channel sizes.
Attributes:
max_det (int): Maximum number of detections.
Methods:
__init__(self, nc=80, ch=()): Initializes the v10Detect object.
forward(self, x): Performs forward pass of the v10Detect module.
bias_init(self): Initializes biases of the Detect module.
"""
end2end = True
def __init__(self, nc=80, ch=()):
"""Initializes the v10Detect object with the specified number of classes and input channels."""
super().__init__(nc, ch)
c3 = max(ch[0], min(self.nc, 100)) # channels
# Light cls head
self.cv3 = nn.ModuleList(
nn.Sequential(
nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)),
nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)),
nn.Conv2d(c3, self.nc, 1),
)
for x in ch
)
self.one2one_cv3 = copy.deepcopy(self.cv3)

View file

@ -15,6 +15,7 @@ from ultralytics.nn.modules import (
C3TR,
ELAN1,
OBB,
PSA,
SPP,
SPPELAN,
SPPF,
@ -24,6 +25,7 @@ from ultralytics.nn.modules import (
BottleneckCSP,
C2f,
C2fAttn,
C2fCIB,
C3Ghost,
C3x,
CBFuse,
@ -46,14 +48,24 @@ from ultralytics.nn.modules import (
RepC3,
RepConv,
RepNCSPELAN4,
RepVGGDW,
ResNetLayer,
RTDETRDecoder,
SCDown,
Segment,
WorldDetect,
v10Detect,
)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.loss import (
E2EDetectLoss,
v8ClassificationLoss,
v8DetectionLoss,
v8OBBLoss,
v8PoseLoss,
v8SegmentationLoss,
)
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (
fuse_conv_and_bn,
@ -192,6 +204,9 @@ class BaseModel(nn.Module):
if isinstance(m, RepConv):
m.fuse_convs()
m.forward = m.forward_fuse # update forward
if isinstance(m, RepVGGDW):
m.fuse()
m.forward = m.forward_fuse
self.info(verbose=verbose)
return self
@ -294,6 +309,7 @@ class DetectionModel(BaseModel):
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
self.inplace = self.yaml.get("inplace", True)
self.end2end = getattr(self.model[-1], "end2end", False)
# Build strides
m = self.model[-1] # Detect()
@ -303,6 +319,8 @@ class DetectionModel(BaseModel):
def _forward(x):
"""Performs a forward pass through the model, handling different Detect subclass types accordingly."""
if self.end2end:
return self.forward(x)["one2many"]
return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward
@ -355,7 +373,7 @@ class DetectionModel(BaseModel):
def init_criterion(self):
"""Initialize the loss criterion for the DetectionModel."""
return v8DetectionLoss(self)
return E2EDetectLoss(self) if self.end2end else v8DetectionLoss(self)
class OBBModel(DetectionModel):
@ -689,8 +707,8 @@ def temporary_modules(modules={}, attributes={}):
Example:
```python
with temporary_modules({'old.module.path': 'new.module.path'}, {'old.module.attribute': 'new.module.attribute'}):
import old.module.path # this will now import new.module.path
with temporary_modules({'old.module': 'new.module'}, {'old.module.attribute': 'new.module.attribute'}):
import old.module # this will now import new.module
from old.module import attribute # this will now import new.module.attribute
```
@ -700,23 +718,19 @@ def temporary_modules(modules={}, attributes={}):
applications or libraries. Use this function with caution.
"""
import importlib
import sys
from importlib import import_module
try:
# Set attributes in sys.modules under their old name
for old, new in attributes.items():
old_module, old_attr = old.rsplit(".", 1)
new_module, new_attr = new.rsplit(".", 1)
setattr(
importlib.import_module(old_module),
old_attr,
getattr(importlib.import_module(new_module), new_attr),
)
setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr))
# Set modules in sys.modules under their old name
for old, new in modules.items():
sys.modules[old] = importlib.import_module(new)
sys.modules[old] = import_module(new)
yield
finally:
@ -750,9 +764,10 @@ def torch_safe_load(weight):
"ultralytics.yolo.data": "ultralytics.data",
},
attributes={
"ultralytics.nn.modules.block.Silence": "torch.nn.Identity",
"ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e
"ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10
},
): # for legacy 8.0 Classify and Pose models
):
ckpt = torch.load(file, map_location="cpu")
except ModuleNotFoundError as e: # e.name is missing module name
@ -911,6 +926,9 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
DWConvTranspose2d,
C3x,
RepC3,
PSA,
SCDown,
C2fCIB,
}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
@ -922,7 +940,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
) # num heads
args = [c1, c2, *args[1:]]
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fCIB}:
args.insert(2, n) # number of repeats
n = 1
elif m is AIFI:
@ -939,7 +957,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -1024,7 +1042,7 @@ def guess_model_task(model):
m = cfg["head"][-1][-2].lower() # output module name
if m in {"classify", "classifier", "cls", "fc"}:
return "classify"
if m == "detect":
if "detect" in m:
return "detect"
if m == "segment":
return "segment"
@ -1056,7 +1074,7 @@ def guess_model_task(model):
return "pose"
elif isinstance(m, OBB):
return "obb"
elif isinstance(m, (Detect, WorldDetect)):
elif isinstance(m, (Detect, WorldDetect, v10Detect)):
return "detect"
# Guess from model filename