ultralytics 8.0.235 YOLOv8 OBB train, val, predict and export (#4499)
Co-authored-by: Yash Khurana <ykhurana6@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Swamita Gupta <swamita2001@gmail.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
parent
f702b34a50
commit
072291bc78
52 changed files with 2090 additions and 524 deletions
|
|
@ -7,13 +7,13 @@ from pathlib import Path
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
|
||||
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
||||
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
||||
ResNetLayer, RTDETRDecoder, Segment)
|
||||
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
|
||||
C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
|
||||
DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
|
||||
RepConv, ResNetLayer, RTDETRDecoder, Segment)
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.utils.plotting import feature_visualization
|
||||
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
|
||||
make_divisible, model_info, scale_img, time_sync)
|
||||
|
|
@ -241,10 +241,10 @@ class DetectionModel(BaseModel):
|
|||
|
||||
# Build strides
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment, Pose)):
|
||||
if isinstance(m, (Detect, Segment, Pose, OBB)):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
|
|
@ -298,6 +298,17 @@ class DetectionModel(BaseModel):
|
|||
return v8DetectionLoss(self)
|
||||
|
||||
|
||||
class OBBModel(DetectionModel):
|
||||
""""YOLOv8 Oriented Bounding Box (OBB) model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
|
||||
"""Initialize YOLOv8 OBB model with given config and parameters."""
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def init_criterion(self):
|
||||
return v8OBBLoss(self)
|
||||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
"""YOLOv8 segmentation model."""
|
||||
|
||||
|
|
@ -616,7 +627,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
# Module updates
|
||||
for m in ensemble.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
|
@ -652,7 +663,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||
# Module updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
|
@ -717,7 +728,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, Segment, Pose):
|
||||
elif m in (Detect, Segment, Pose, OBB):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
|
|
@ -801,6 +812,8 @@ def guess_model_task(model):
|
|||
return 'segment'
|
||||
if m == 'pose':
|
||||
return 'pose'
|
||||
if m == 'obb':
|
||||
return 'obb'
|
||||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
|
|
@ -825,6 +838,8 @@ def guess_model_task(model):
|
|||
return 'classify'
|
||||
elif isinstance(m, Pose):
|
||||
return 'pose'
|
||||
elif isinstance(m, OBB):
|
||||
return 'obb'
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
|
|
@ -835,10 +850,12 @@ def guess_model_task(model):
|
|||
return 'classify'
|
||||
elif '-pose' in model.stem or 'pose' in model.parts:
|
||||
return 'pose'
|
||||
elif '-obb' in model.stem or 'obb' in model.parts:
|
||||
return 'obb'
|
||||
elif 'detect' in model.parts:
|
||||
return 'detect'
|
||||
|
||||
# Unable to determine task from model
|
||||
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
|
||||
return 'detect' # assume detect
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue