ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)

Co-authored-by: Yash Khurana <ykhurana6@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Swamita Gupta <swamita2001@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-01-05 03:00:26 +01:00 committed by GitHub
parent f702b34a50
commit 072291bc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 2090 additions and 524 deletions

View file

@ -21,7 +21,7 @@ from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP,
HGBlock, HGStem, Proto, RepC3, ResNetLayer)
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
GhostConv, LightConv, RepConv, SpatialAttention)
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
@ -30,4 +30,5 @@ __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer')
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer',
'OBB')

View file

@ -7,14 +7,14 @@ import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder'
class Detect(nn.Module):
@ -41,22 +41,24 @@ class Detect(nn.Module):
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training:
if self.training: # Training path
return x
elif self.dynamic or self.shape != shape:
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
box = x_cat[:, :self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4:]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
dbox = self.decode_bboxes(box)
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
@ -79,6 +81,10 @@ class Detect(nn.Module):
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes):
"""Decode bounding boxes."""
return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
class Segment(Detect):
"""YOLOv8 Segment head for segmentation models."""
@ -106,6 +112,35 @@ class Segment(Detect):
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
class OBB(Detect):
"""YOLOv8 OBB detection head for detection with rotation models."""
def __init__(self, nc=80, ne=1, ch=()):
super().__init__(nc, ch)
self.ne = ne # number of extra parameters
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.ne)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
def forward(self, x):
bs = x[0].shape[0] # batch size
angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
if not self.training:
self.angle = angle
x = self.detect(self, x)
if self.training:
return x, angle
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
def decode_bboxes(self, bboxes):
"""Decode rotated bounding boxes."""
return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides
class Pose(Detect):
"""YOLOv8 Pose head for keypoints models."""