ultralytics 8.1.14 new YOLOv8-World models (#8054)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent f9e9cdf2c3
commit 850ca8587f

19 changed files with 683 additions and 32 deletions
ultralytics/nn/modules/__init__.py

@@ -28,6 +28,8 @@ from .block import (
     Bottleneck,
     BottleneckCSP,
     C2f,
+    C2fAttn,
+    ImagePoolingAttn,
     C3Ghost,
     C3x,
     GhostBottleneck,
@@ -36,6 +38,8 @@ from .block import (
     Proto,
     RepC3,
     ResNetLayer,
+    ContrastiveHead,
+    BNContrastiveHead,
 )
 from .conv import (
     CBAM,
@@ -52,7 +56,7 @@ from .conv import (
     RepConv,
     SpatialAttention,
 )
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
 from .transformer import (
     AIFI,
     MLP,
@@ -93,6 +97,7 @@ __all__ = (
     "C2",
     "C3",
     "C2f",
+    "C2fAttn",
     "C3x",
     "C3TR",
     "C3Ghost",
@@ -114,4 +119,8 @@ __all__ = (
     "MLP",
     "ResNetLayer",
     "OBB",
+    "WorldDetect",
+    "ImagePoolingAttn",
+    "ContrastiveHead",
+    "BNContrastiveHead",
 )
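These re-export hunks make the new YOLO-World blocks part of the package's public surface, so downstream code can import them from ultralytics.nn.modules directly. A minimal import sketch (names taken verbatim from the diff above; usage is illustrative, not part of the commit):

from ultralytics.nn.modules import (
    C2fAttn,
    ImagePoolingAttn,
    ContrastiveHead,
    BNContrastiveHead,
    WorldDetect,
)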
ultralytics/nn/modules/block.py

@@ -18,6 +18,10 @@ __all__ = (
     "C2",
     "C3",
     "C2f",
+    "C2fAttn",
+    "ImagePoolingAttn",
+    "ContrastiveHead",
+    "BNContrastiveHead",
     "C3x",
     "C3TR",
     "C3Ghost",
@@ -390,3 +394,157 @@ class ResNetLayer(nn.Module):
     def forward(self, x):
         """Forward pass through the ResNet layer."""
         return self.layer(x)
+
+
+class MaxSigmoidAttnBlock(nn.Module):
+    """Max Sigmoid attention block."""
+
+    def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False):
+        """Initializes MaxSigmoidAttnBlock with specified arguments."""
+        super().__init__()
+        self.nh = nh
+        self.hc = c2 // nh
+        self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None
+        self.gl = nn.Linear(gc, ec)
+        self.bias = nn.Parameter(torch.zeros(nh))
+        self.proj_conv = Conv(c1, c2, k=3, s=1, act=False)
+        self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0
+
+    def forward(self, x, guide):
+        """Forward process."""
+        bs, _, h, w = x.shape
+
+        guide = self.gl(guide)
+        guide = guide.view(bs, -1, self.nh, self.hc)
+        embed = self.ec(x) if self.ec is not None else x
+        embed = embed.view(bs, self.nh, self.hc, h, w)
+
+        aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide)
+        aw = aw.max(dim=-1)[0]
+        aw = aw / (self.hc**0.5)
+        aw = aw + self.bias[None, :, None, None]
+        aw = aw.sigmoid() * self.scale
+
+        x = self.proj_conv(x)
+        x = x.view(bs, self.nh, -1, h, w)
+        x = x * aw.unsqueeze(2)
+        return x.view(bs, -1, h, w)
+
+
+class C2fAttn(nn.Module):
+    """C2f module with an additional attn module."""
+
+    def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5):
+        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
+        expansion.
+        """
+        super().__init__()
+        self.c = int(c2 * e)  # hidden channels
+        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
+        self.cv2 = Conv((3 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
+        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
+        self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh)
+
+    def forward(self, x, guide):
+        """Forward pass through C2f layer."""
+        y = list(self.cv1(x).chunk(2, 1))
+        y.extend(m(y[-1]) for m in self.m)
+        y.append(self.attn(y[-1], guide))
+        return self.cv2(torch.cat(y, 1))
+
+    def forward_split(self, x, guide):
+        """Forward pass using split() instead of chunk()."""
+        y = list(self.cv1(x).split((self.c, self.c), 1))
+        y.extend(m(y[-1]) for m in self.m)
+        y.append(self.attn(y[-1], guide))
+        return self.cv2(torch.cat(y, 1))
+
+
+class ImagePoolingAttn(nn.Module):
+    """ImagePoolingAttn: Enhance the text embeddings with image-aware information."""
+
+    def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False):
+        """Initializes ImagePoolingAttn with specified arguments."""
+        super().__init__()
+
+        nf = len(ch)
+        self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec))
+        self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
+        self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec))
+        self.proj = nn.Linear(ec, ct)
+        self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0
+        self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch])
+        self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)])
+        self.ec = ec
+        self.nh = nh
+        self.nf = nf
+        self.hc = ec // nh
+        self.k = k
+
+    def forward(self, x, text):
+        """Executes attention mechanism on input tensor x and guide tensor."""
+        bs = x[0].shape[0]
+        assert len(x) == self.nf
+        num_patches = self.k**2
+        x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)]
+        x = torch.cat(x, dim=-1).transpose(1, 2)
+        q = self.query(text)
+        k = self.key(x)
+        v = self.value(x)
+
+        # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1)
+        q = q.reshape(bs, -1, self.nh, self.hc)
+        k = k.reshape(bs, -1, self.nh, self.hc)
+        v = v.reshape(bs, -1, self.nh, self.hc)
+
+        aw = torch.einsum("bnmc,bkmc->bmnk", q, k)
+        aw = aw / (self.hc**0.5)
+        aw = F.softmax(aw, dim=-1)
+
+        x = torch.einsum("bmnk,bkmc->bnmc", aw, v)
+        x = self.proj(x.reshape(bs, -1, self.ec))
+        return x * self.scale + text
+
+
+class ContrastiveHead(nn.Module):
+    """Contrastive Head for YOLO-World compute the region-text scores according to the similarity between image and text
+    features.
+    """
+
+    def __init__(self):
+        """Initializes ContrastiveHead with specified region-text similarity parameters."""
+        super().__init__()
+        self.bias = nn.Parameter(torch.zeros([]))
+        self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
+
+    def forward(self, x, w):
+        """Forward function of contrastive learning."""
+        x = F.normalize(x, dim=1, p=2)
+        w = F.normalize(w, dim=-1, p=2)
+        x = torch.einsum("bchw,bkc->bkhw", x, w)
+        return x * self.logit_scale.exp() + self.bias
+
+
+class BNContrastiveHead(nn.Module):
+    """
+    Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization.
+
+    Args:
+        embed_dims (int): Embed dimensions of text and image features.
+        norm_cfg (dict): Normalization parameters.
+    """
+
+    def __init__(self, embed_dims: int):
+        """Initialize ContrastiveHead with region-text similarity parameters."""
+        super().__init__()
+        self.norm = nn.BatchNorm2d(embed_dims)
+        self.bias = nn.Parameter(torch.zeros([]))
+        # use -1.0 is more stable
+        self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))
+
+    def forward(self, x, w):
+        """Forward function of contrastive learning."""
+        x = self.norm(x)
+        w = F.normalize(w, dim=-1, p=2)
+        x = torch.einsum("bchw,bkc->bkhw", x, w)
+        return x * self.logit_scale.exp() + self.bias
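MaxSigmoidAttnBlock projects the guide (text) embeddings into per-head channel space, takes the maximum similarity over all guides at each spatial position, and gates the projected image features with a sigmoid of that score. A minimal shape sketch (values are illustrative, not from the commit); note the implicit constraint that ec must equal nh * hc, i.e. ec == c2, for the reshapes in forward() to hold:

import torch
from ultralytics.nn.modules.block import MaxSigmoidAttnBlock

attn = MaxSigmoidAttnBlock(c1=128, c2=128, nh=4, ec=128, gc=512)
x = torch.randn(2, 128, 40, 40)   # image features (bs, c1, h, w)
guide = torch.randn(2, 32, 512)   # guide/text embeddings (bs, num_guides, gc)
out = attn(x, guide)              # (2, 128, 40, 40): features gated by max guide similarity
print(out.shape)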
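C2fAttn extends the standard C2f bottleneck: after the usual n Bottleneck stages it appends one MaxSigmoidAttnBlock output conditioned on the guide tensor, which is why cv2 fuses (3 + n) hidden chunks instead of the (2 + n) of plain C2f. A hedged usage sketch with hypothetical sizes; ec is chosen to match the hidden width so the attention block's head split works:

import torch
from ultralytics.nn.modules.block import C2fAttn

m = C2fAttn(c1=256, c2=256, n=2, ec=128, nh=4, gc=512)  # hidden width = int(256 * 0.5) = 128 = ec
x = torch.randn(2, 256, 40, 40)   # image features
guide = torch.randn(2, 32, 512)   # text embeddings (bs, num_texts, gc)
print(m(x, guide).shape)          # torch.Size([2, 256, 40, 40])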
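ImagePoolingAttn runs attention in the opposite direction: each feature level is adaptively max-pooled to k*k patches and projected to ec channels, the text embeddings act as queries over the pooled patches, and the attended output is added back to the text as a scaled residual. A sketch under assumed shapes (three pyramid levels, 32 text prompts):

import torch
from ultralytics.nn.modules.block import ImagePoolingAttn

m = ImagePoolingAttn(ec=256, ch=(128, 256, 512), ct=512, nh=8, k=3)
feats = [torch.randn(2, c, s, s) for c, s in ((128, 80), (256, 40), (512, 20))]
text = torch.randn(2, 32, 512)   # (bs, num_texts, ct)
print(m(feats, text).shape)      # torch.Size([2, 32, 512]), text enriched with image context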
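The two contrastive heads share the same einsum, scoring every spatial location against every text embedding; they differ only in how image features are normalized (l2 vs. BatchNorm2d) and how the learned temperature is initialized (log(1/0.07), the CLIP-style value, vs. -1.0). A shape sketch with illustrative sizes:

import torch
from ultralytics.nn.modules.block import ContrastiveHead, BNContrastiveHead

x = torch.randn(2, 512, 40, 40)  # region embeddings (bs, c, h, w)
w = torch.randn(2, 80, 512)      # text embeddings (bs, k, c)
print(ContrastiveHead()(x, w).shape)       # torch.Size([2, 80, 40, 40]) region-text logits
print(BNContrastiveHead(512)(x, w).shape)  # torch.Size([2, 80, 40, 40])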
ultralytics/nn/modules/head.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from torch.nn.init import constant_, xavier_uniform_
 
 from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
-from .block import DFL, Proto
+from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
 from .conv import Conv
 from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
 from .utils import bias_init_with_prob, linear_init
@@ -208,6 +208,49 @@ class Classify(nn.Module):
         return x if self.training else x.softmax(1)
 
 
+class WorldDetect(Detect):
+    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
+        """Initialize YOLOv8 detection layer with nc classes and layer channels ch."""
+        super().__init__(nc, ch)
+        c3 = max(ch[0], min(self.nc, 100))
+        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch)
+        self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch)
+
+    def forward(self, x, text):
+        """Concatenates and returns predicted bounding boxes and class probabilities."""
+        for i in range(self.nl):
+            x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1)
+        if self.training:
+            return x
+
+        # Inference path
+        shape = x[0].shape  # BCHW
+        x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2)
+        if self.dynamic or self.shape != shape:
+            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
+            self.shape = shape
+
+        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
+            box = x_cat[:, : self.reg_max * 4]
+            cls = x_cat[:, self.reg_max * 4 :]
+        else:
+            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
+
+        if self.export and self.format in ("tflite", "edgetpu"):
+            # Precompute normalization factor to increase numerical stability
+            # See https://github.com/ultralytics/ultralytics/issues/7371
+            grid_h = shape[2]
+            grid_w = shape[3]
+            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
+            norm = self.strides / (self.stride[0] * grid_size)
+            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
+        else:
+            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
+
+        y = torch.cat((dbox, cls.sigmoid()), 1)
+        return y if self.export else (y, x)
+
+
 class RTDETRDecoder(nn.Module):
     """
     Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.
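WorldDetect reuses the Detect box branch (cv2, producing 4 * reg_max DFL channels per level) but replaces the class branch: cv3 now projects features into the text embedding space and cv4 scores them against the text prompts with a contrastive head, so class logits come from region-text similarity rather than a fixed nc-way convolution. A hedged end-to-end sketch; the shapes and values are assumptions, and it relies on Detect defaults such as reg_max=16 that are not shown in this diff:

import torch
from ultralytics.nn.modules.head import WorldDetect

nc, embed = 80, 512
head = WorldDetect(nc=nc, embed=embed, with_bn=True, ch=(128, 256, 512))
feats = [torch.randn(2, c, s, s) for c, s in ((128, 80), (256, 40), (512, 20))]
text = torch.randn(2, nc, embed)  # one embedding per class prompt
head.train()                      # training path returns raw per-level maps
out = head(feats, text)           # each (2, 4 * reg_max + nc, h, w)
print([o.shape for o in out])

At the model level these pieces back the YOLOv8-World checkpoints named in the commit title, presumably loaded through the usual YOLO interface (e.g. YOLO("yolov8s-world.pt"), not shown in this diff).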