ultralytics 8.2.38 official YOLOv10 support (#13113)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
parent 821e5fa477
commit ffb46fd7fb
23 changed files with 785 additions and 32 deletions
@@ -5,6 +5,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.torch_utils import fuse_conv_and_bn

from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad
from .transformer import TransformerBlock

@@ -39,6 +41,12 @@ __all__ = (
    "CBFuse",
    "CBLinear",
    "Silence",
    "RepVGGDW",
    "CIB",
    "C2fCIB",
    "Attention",
    "PSA",
    "SCDown",
)

@@ -699,3 +707,251 @@ class CBFuse(nn.Module):
        target_size = xs[-1].shape[2:]
        res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])]
        return torch.sum(torch.stack(res + xs[-1:]), dim=0)

class RepVGGDW(torch.nn.Module):
    """RepVGGDW is a class that represents a depthwise convolutional block in the RepVGG architecture."""

    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False)
        self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False)
        self.dim = ed
        self.act = nn.SiLU()

    def forward(self, x):
        """
        Performs a forward pass of the RepVGGDW block.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after applying the depthwise convolution.
        """
        return self.act(self.conv(x) + self.conv1(x))

    def forward_fuse(self, x):
        """
        Performs a forward pass of the RepVGGDW block using the single fused convolution.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after applying the depthwise convolution.
        """
        return self.act(self.conv(x))

    @torch.no_grad()
    def fuse(self):
        """
        Fuses the convolutional layers in the RepVGGDW block.

        This method fuses the convolutional layers and updates the weights and biases accordingly.
        """
        conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn)
        conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn)

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2])

        final_conv_w = conv_w + conv1_w
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        self.conv = conv
        del self.conv1

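Example (not part of the diff): a minimal sketch of how the fuse() re-parameterization above can be sanity-checked, assuming these classes live in ultralytics.nn.modules.block, as the relative imports in this hunk suggest. Since the 3x3 kernel is zero-padded to 7x7 and summed into the 7x7 kernel, the fused single-conv path should reproduce the two-branch eval-mode output.

import torch

from ultralytics.nn.modules.block import RepVGGDW  # assumed module path

m = RepVGGDW(64).eval()  # depthwise block with 64 channels; eval() so BN uses running stats
x = torch.randn(1, 64, 32, 32)

with torch.no_grad():
    y_two_branch = m.forward(x)  # 7x7 branch + 3x3 branch
    m.fuse()                     # fold BN and both kernels into a single 7x7 depthwise conv
    y_fused = m.forward_fuse(x)  # single-conv path used after fusing

# Convolution is linear, so summing the (BN-folded, padded) kernels preserves the output.
assert torch.allclose(y_two_branch, y_fused, atol=1e-4)
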
class CIB(nn.Module):
    """
    Conditional Identity Block (CIB) module.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True.
        e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5.
        lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False.
    """

    def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False):
        """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = nn.Sequential(
            Conv(c1, c1, 3, g=c1),
            Conv(c1, 2 * c_, 1),
            Conv(2 * c_, 2 * c_, 3, g=2 * c_) if not lk else RepVGGDW(2 * c_),
            Conv(2 * c_, c2, 1),
            Conv(c2, c2, 3, g=c2),
        )

        self.add = shortcut and c1 == c2

    def forward(self, x):
        """
        Forward pass of the CIB module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        return x + self.cv1(x) if self.add else self.cv1(x)

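For illustration (not part of the diff): a quick shape check for CIB under the same assumed import path. With c1 == c2 and shortcut=True the residual add is active, so the block is shape-preserving.

import torch

from ultralytics.nn.modules.block import CIB  # assumed module path

m = CIB(128, 128, shortcut=True, e=0.5, lk=False)
x = torch.randn(2, 128, 40, 40)
assert m(x).shape == x.shape  # residual add active: shortcut=True and c1 == c2
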
class C2fCIB(C2f):
    """
    C2fCIB class represents a convolutional block with C2f and CIB modules.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        n (int, optional): Number of CIB modules to stack. Defaults to 1.
        shortcut (bool, optional): Whether to use shortcut connections. Defaults to False.
        lk (bool, optional): Whether to use RepVGGDW (large-kernel) blocks inside the CIB modules. Defaults to False.
        g (int, optional): Number of groups for grouped convolution. Defaults to 1.
        e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5):
        """Initializes the module with specified parameters for channels, shortcut, large-kernel option, groups, and expansion."""
        super().__init__(c1, c2, n, shortcut, g, e)
        self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))

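For illustration (not part of the diff): a usage sketch for C2fCIB under the same assumed import path. The inherited C2f forward splits the 1x1-projected features in two, runs the n stacked CIB modules on one half, and fuses everything with a final 1x1 conv, so spatial size is preserved while channels become c2.

import torch

from ultralytics.nn.modules.block import C2fCIB  # assumed module path

m = C2fCIB(128, 256, n=2, shortcut=True, lk=True)  # lk=True swaps RepVGGDW into each CIB
x = torch.randn(1, 128, 20, 20)
assert m(x).shape == (1, 256, 20, 20)  # channels 128 -> 256, spatial size unchanged
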
class Attention(nn.Module):
    """
    Attention module that performs self-attention on the input tensor.

    Args:
        dim (int): The input tensor dimension.
        num_heads (int): The number of attention heads.
        attn_ratio (float): The ratio of the attention key dimension to the head dimension.

    Attributes:
        num_heads (int): The number of attention heads.
        head_dim (int): The dimension of each attention head.
        key_dim (int): The dimension of the attention key.
        scale (float): The scaling factor for the attention scores.
        qkv (Conv): Convolutional layer for computing the query, key, and value.
        proj (Conv): Convolutional layer for projecting the attended values.
        pe (Conv): Convolutional layer for positional encoding.
    """

    def __init__(self, dim, num_heads=8, attn_ratio=0.5):
        """Initializes multi-head attention module with query, key, and value convolutions and positional encoding."""
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim**-0.5
        nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2
        self.qkv = Conv(dim, h, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)

    def forward(self, x):
        """
        Forward pass of the Attention module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            (torch.Tensor): The output tensor after self-attention.
        """
        B, C, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split(
            [self.key_dim, self.key_dim, self.head_dim], dim=2
        )

        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
        x = self.proj(x)
        return x

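For illustration (not part of the diff): a shape check for the Attention module under the same assumed import path. dim must be divisible by num_heads; with attn_ratio=0.5 each head uses key_dim = head_dim // 2, and attention is computed over the N = H * W spatial positions.

import torch

from ultralytics.nn.modules.block import Attention  # assumed module path

m = Attention(dim=256, num_heads=8, attn_ratio=0.5)  # head_dim = 32, key_dim = 16 per head
x = torch.randn(1, 256, 16, 16)
assert m(x).shape == x.shape  # self-attention over N = 16 * 16 positions, shape preserved
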
class PSA(nn.Module):
    """
    Position-wise Spatial Attention module.

    Args:
        c1 (int): Number of input channels.
        c2 (int): Number of output channels.
        e (float): Expansion factor for the intermediate channels. Default is 0.5.

    Attributes:
        c (int): Number of intermediate channels.
        cv1 (Conv): 1x1 convolution layer to project the input channels to 2*c.
        cv2 (Conv): 1x1 convolution layer to restore the channel count from 2*c back to c1.
        attn (Attention): Attention module for spatial attention.
        ffn (nn.Sequential): Feed-forward network module.
    """

    def __init__(self, c1, c2, e=0.5):
        """Initializes convolution layers, attention module, and feed-forward network with channel reduction."""
        super().__init__()
        assert c1 == c2
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
        self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False))

    def forward(self, x):
        """
        Forward pass of the PSA module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = b + self.attn(b)
        b = b + self.ffn(b)
        return self.cv2(torch.cat((a, b), 1))

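For illustration (not part of the diff): a PSA usage sketch under the same assumed import path. PSA asserts c1 == c2, attends on half the channels (e=0.5), and sets num_heads = c // 64, so c1 is chosen here as a multiple of 128.

import torch

from ultralytics.nn.modules.block import PSA  # assumed module path

m = PSA(256, 256)  # c = 128 intermediate channels, num_heads = 128 // 64 = 2
x = torch.randn(1, 256, 20, 20)
assert m(x).shape == x.shape  # attention + FFN on one half, concatenated, fused back to c1
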
class SCDown(nn.Module):
    def __init__(self, c1, c2, k, s):
        """
        Spatial Channel Downsample (SCDown) module.

        Args:
            c1 (int): Number of input channels.
            c2 (int): Number of output channels.
            k (int): Kernel size for the convolutional layer.
            s (int): Stride for the convolutional layer.
        """
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False)

    def forward(self, x):
        """
        Forward pass of the SCDown module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor after applying the SCDown module.
        """
        return self.cv2(self.cv1(x))

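For illustration (not part of the diff): a shape check for SCDown under the same assumed import path. The pointwise conv changes the channel count and the depthwise strided conv halves the spatial resolution, which is cheaper than a single dense strided convolution.

import torch

from ultralytics.nn.modules.block import SCDown  # assumed module path

m = SCDown(256, 512, k=3, s=2)  # 1x1 channel projection, then 3x3 stride-2 depthwise conv
x = torch.randn(1, 256, 40, 40)
assert m(x).shape == (1, 512, 20, 20)  # channels 256 -> 512, spatial 40 -> 20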