ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)

Co-authored-by: Yash Khurana <ykhurana6@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Swamita Gupta <swamita2001@gmail.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
Glenn Jocher 2024-01-05 03:00:26 +01:00 committed by GitHub
parent f702b34a50
commit 072291bc78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
52 changed files with 2090 additions and 524 deletions

View file

@ -21,7 +21,7 @@ from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP,
HGBlock, HGStem, Proto, RepC3, ResNetLayer)
from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
GhostConv, LightConv, RepConv, SpatialAttention)
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
@ -30,4 +30,5 @@ __all__ = ('Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d
'TransformerBlock', 'MLPBlock', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3',
'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect',
'Segment', 'Pose', 'Classify', 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI',
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer')
'DeformableTransformerDecoder', 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP', 'ResNetLayer',
'OBB')

View file

@ -7,14 +7,14 @@ import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, make_anchors
from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors
from .block import DFL, Proto
from .conv import Conv
from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer
from .utils import bias_init_with_prob, linear_init_
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'RTDETRDecoder'
__all__ = 'Detect', 'Segment', 'Pose', 'Classify', 'OBB', 'RTDETRDecoder'
class Detect(nn.Module):
@ -41,22 +41,24 @@ class Detect(nn.Module):
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training:
if self.training: # Training path
return x
elif self.dynamic or self.shape != shape:
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'): # avoid TF FlexSplitV ops
box = x_cat[:, :self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4:]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
dbox = self.decode_bboxes(box)
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
@ -79,6 +81,10 @@ class Detect(nn.Module):
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes):
"""Decode bounding boxes."""
return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
class Segment(Detect):
"""YOLOv8 Segment head for segmentation models."""
@ -106,6 +112,35 @@ class Segment(Detect):
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
class OBB(Detect):
"""YOLOv8 OBB detection head for detection with rotation models."""
def __init__(self, nc=80, ne=1, ch=()):
super().__init__(nc, ch)
self.ne = ne # number of extra parameters
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.ne)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch)
def forward(self, x):
bs = x[0].shape[0] # batch size
angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits
# NOTE: set `angle` as an attribute so that `decode_bboxes` could use it.
angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]
# angle = angle.sigmoid() * math.pi / 2 # [0, pi/2]
if not self.training:
self.angle = angle
x = self.detect(self, x)
if self.training:
return x, angle
return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle))
def decode_bboxes(self, bboxes):
"""Decode rotated bounding boxes."""
return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides
class Pose(Detect):
"""YOLOv8 Pose head for keypoints models."""

View file

@ -7,13 +7,13 @@ from pathlib import Path
import torch
import torch.nn as nn
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
ResNetLayer, RTDETRDecoder, Segment)
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
RepConv, ResNetLayer, RTDETRDecoder, Segment)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
make_divisible, model_info, scale_img, time_sync)
@ -241,10 +241,10 @@ class DetectionModel(BaseModel):
# Build strides
m = self.model[-1] # Detect()
if isinstance(m, (Detect, Segment, Pose)):
if isinstance(m, (Detect, Segment, Pose, OBB)):
s = 256 # 2x min stride
m.inplace = self.inplace
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride
m.bias_init() # only run once
@ -298,6 +298,17 @@ class DetectionModel(BaseModel):
return v8DetectionLoss(self)
class OBBModel(DetectionModel):
""""YOLOv8 Oriented Bounding Box (OBB) model."""
def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 OBB model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def init_criterion(self):
return v8OBBLoss(self)
class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model."""
@ -616,7 +627,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
# Module updates
for m in ensemble.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
@ -652,7 +663,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
# Module updates
for m in model.modules():
t = type(m)
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
m.inplace = inplace
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
m.recompute_scale_factor = None # torch 1.11.0 compatibility
@ -717,7 +728,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in (Detect, Segment, Pose):
elif m in (Detect, Segment, Pose, OBB):
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -801,6 +812,8 @@ def guess_model_task(model):
return 'segment'
if m == 'pose':
return 'pose'
if m == 'obb':
return 'obb'
# Guess from model cfg
if isinstance(model, dict):
@ -825,6 +838,8 @@ def guess_model_task(model):
return 'classify'
elif isinstance(m, Pose):
return 'pose'
elif isinstance(m, OBB):
return 'obb'
# Guess from model filename
if isinstance(model, (str, Path)):
@ -835,10 +850,12 @@ def guess_model_task(model):
return 'classify'
elif '-pose' in model.stem or 'pose' in model.parts:
return 'pose'
elif '-obb' in model.stem or 'obb' in model.parts:
return 'obb'
elif 'detect' in model.parts:
return 'detect'
# Unable to determine task from model
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
return 'detect' # assume detect