ultralytics 8.1.14 new YOLOv8-World models (#8054)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-02-14 18:46:26 +08:00 committed by GitHub
parent f9e9cdf2c3
commit 850ca8587f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 683 additions and 32 deletions

View file

@ -8,7 +8,7 @@ import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto
from .block import DFL, Proto, ContrastiveHead, BNContrastiveHead
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init
@ -208,6 +208,49 @@ class Classify(nn.Module):
return x if self.training else x.softmax(1)
class WorldDetect(Detect):
    """Detect head variant whose class branch also consumes text features.

    Each level's feature map is projected to an ``embed``-channel map by ``cv3`` and then passed,
    together with ``text``, to a per-level ``cv4`` head (``BNContrastiveHead`` or ``ContrastiveHead``).
    """

    def __init__(self, nc=80, embed=512, with_bn=False, ch=()):
        """Initialize the head with nc classes, embedding width embed and layer channels ch.

        Args:
            nc: Number of classes, forwarded to the parent ``Detect`` initializer.
            embed: Output channels of every ``cv3`` branch; also passed to ``BNContrastiveHead``.
            with_bn: Build ``BNContrastiveHead(embed)`` heads when truthy, ``ContrastiveHead()`` otherwise.
            ch: Input channels of each detection level; ``ch[0]`` is read, so it must be non-empty.
        """
        super().__init__(nc, ch)
        hidden = max(ch[0], min(self.nc, 100))
        # One embedding branch per input level: two 3x3 Conv blocks followed by a 1x1 projection.
        embed_branches = [
            nn.Sequential(Conv(c_in, hidden, 3), Conv(hidden, hidden, 3), nn.Conv2d(hidden, embed, 1))
            for c_in in ch
        ]
        self.cv3 = nn.ModuleList(embed_branches)
        # One head per level that combines the embedding map with the text features.
        contrast_heads = [BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch]
        self.cv4 = nn.ModuleList(contrast_heads)

    def forward(self, x, text):
        """Concatenate box and class branches per level and, at inference, decode the predictions.

        Args:
            x: List of per-level feature maps; its entries are overwritten in place.
            text: Text features handed unchanged to every ``cv4`` head.

        Returns:
            The updated list ``x`` while training. Otherwise the decoded tensor ``y`` when
            exporting, or the tuple ``(y, x)`` when not exporting.
        """
        for level in range(self.nl):
            feats = x[level]
            box_out = self.cv2[level](feats)
            cls_out = self.cv4[level](self.cv3[level](feats), text)
            x[level] = torch.cat((box_out, cls_out), 1)
        if self.training:
            return x

        # Inference path
        shape = x[0].shape  # BCHW
        box_ch = self.reg_max * 4
        flattened = [xi.view(shape[0], self.nc + box_ch, -1) for xi in x]
        x_cat = torch.cat(flattened, 2)

        # Rebuild anchors/strides only when running dynamically or when the input shape changed.
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (t.transpose(0, 1) for t in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        tf_formats = ("saved_model", "pb", "tflite", "edgetpu", "tfjs")
        if self.export and self.format in tf_formats:  # avoid TF FlexSplitV ops
            box, cls = x_cat[:, :box_ch], x_cat[:, box_ch:]
        else:
            box, cls = x_cat.split((box_ch, self.nc), 1)

        if self.export and self.format in ("tflite", "edgetpu"):
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h, grid_w = shape[2], shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides

        y = torch.cat((dbox, cls.sigmoid()), 1)
        if self.export:
            return y
        return y, x
class RTDETRDecoder(nn.Module):
"""
Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection.