ultralytics 8.0.65 YOLOv8 Pose models (#1347)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mert Can Demir <validatedev@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Fabian Greavu <fabiangreavu@gmail.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Eric Pedley <ericpedley@gmail.com>
Co-authored-by: JustasBart <40023722+JustasBart@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Aarni Koskela <akx@iki.fi>
Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com>
Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com>
Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl>
Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com>
Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com>
Co-authored-by: Farid Inawan <frdteknikelektro@gmail.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Alexander Duda <Alexander.Duda@me.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
parent
9af3e69b1a
commit
1cb92d7f42
57 changed files with 1578 additions and 489 deletions
|
|
@@ -10,7 +10,7 @@ import torch.nn as nn
|
|||
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
GhostBottleneck, GhostConv, Pose, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
|
|
@@ -183,10 +183,10 @@ class DetectionModel(BaseModel):
|
|||
|
||||
# Build strides
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
if isinstance(m, (Detect, Segment, Pose)):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
|
|
@@ -242,12 +242,23 @@ class DetectionModel(BaseModel):
|
|||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv8 segmentation model
|
||||
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg, ch, nc, verbose)
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def _forward_augment(self, x):
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
||||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
# YOLOv8 pose model
|
||||
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = yaml_model_load(cfg) # load model YAML
|
||||
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
|
||||
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
||||
cfg['kpt_shape'] = data_kpt_shape
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
# YOLOv8 classification model
|
||||
def __init__(self,
|
||||
|
|
@@ -425,7 +436,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
# Args
|
||||
max_channels = float('inf')
|
||||
nc, act, scales = (d.get(x) for x in ('nc', 'act', 'scales'))
|
||||
depth, width = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple'))
|
||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
||||
if scales:
|
||||
scale = d.get('scale')
|
||||
if not scale:
|
||||
|
|
@@ -464,7 +475,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, Segment):
|
||||
elif m in (Detect, Segment, Pose):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
|
|
@@ -543,6 +554,8 @@ def guess_model_task(model):
|
|||
return 'detect'
|
||||
if m == 'segment':
|
||||
return 'segment'
|
||||
if m == 'pose':
|
||||
return 'pose'
|
||||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
|
|
@@ -565,6 +578,8 @@ def guess_model_task(model):
|
|||
return 'segment'
|
||||
elif isinstance(m, Classify):
|
||||
return 'classify'
|
||||
elif isinstance(m, Pose):
|
||||
return 'pose'
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
|
|
@@ -573,10 +588,12 @@ def guess_model_task(model):
|
|||
return 'segment'
|
||||
elif '-cls' in model.stem or 'classify' in model.parts:
|
||||
return 'classify'
|
||||
elif '-pose' in model.stem or 'pose' in model.parts:
|
||||
return 'pose'
|
||||
elif 'detect' in model.parts:
|
||||
return 'detect'
|
||||
|
||||
# Unable to determine task from model
|
||||
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
|
||||
return 'detect' # assume detect
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue