ultralytics 8.0.98 add Baidu RT-DETR models (#2527)

Co-authored-by: Kalen Michael <kalenmike@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Dowon <ks2515@naver.com>
This commit is contained in:
Glenn Jocher 2023-05-11 01:15:20 +02:00 committed by GitHub
parent d1107ca4cb
commit 229119c376
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 2172 additions and 834 deletions

View file

@ -8,9 +8,10 @@ import thop
import torch
import torch.nn as nn
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
GhostBottleneck, GhostConv, Pose, Segment)
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.plotting import feature_visualization
@ -105,6 +106,9 @@ class BaseModel(nn.Module):
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
if isinstance(m, RepConv):
m.fuse_convs()
m.forward = m.forward_fuse # update forward
self.info(verbose=verbose)
return self
@ -334,6 +338,22 @@ class ClassificationModel(BaseModel):
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
class Ensemble(nn.ModuleList):
"""Ensemble of models."""
def __init__(self):
"""Initialize an ensemble of models."""
super().__init__()
def forward(self, x, augment=False, profile=False, visualize=False):
"""Function generates the YOLOv5 network's final layer."""
y = [module(x, augment, profile, visualize)[0] for module in self]
# y = torch.stack(y).max(0)[0] # max ensemble
# y = torch.stack(y).mean(0) # mean ensemble
y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C)
return y, None # inference, train output
# Functions ------------------------------------------------------------------------------------------------------------
@ -415,7 +435,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
"""Loads a single model weights."""
ckpt, weight = torch_safe_load(weight) # load ckpt
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
# Model compatibility updates
@ -472,20 +492,29 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
args = [c1, c2, *args[1:]]
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x):
if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3):
args.insert(2, n) # number of repeats
n = 1
elif m is AIFI:
args = [ch[f], *args]
elif m in (HGStem, HGBlock):
c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]]
if m is HGBlock:
args.insert(4, n) # number of repeats
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in (Detect, Segment, Pose):
elif m in (Detect, Segment, Pose, RTDETRDecoder):
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)