Backward compatibility support for legacy models (#17010)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
8d7d1fe390
commit
899fb0495e
2 changed files with 18 additions and 8 deletions
|
|
@ -28,6 +28,7 @@ class Detect(nn.Module):
|
||||||
shape = None
|
shape = None
|
||||||
anchors = torch.empty(0) # init
|
anchors = torch.empty(0) # init
|
||||||
strides = torch.empty(0) # init
|
strides = torch.empty(0) # init
|
||||||
|
legacy = False # backward compatibility for v3/v5/v8/v9 models
|
||||||
|
|
||||||
def __init__(self, nc=80, ch=()):
|
def __init__(self, nc=80, ch=()):
|
||||||
"""Initializes the YOLO detection layer with specified number of classes and channels."""
|
"""Initializes the YOLO detection layer with specified number of classes and channels."""
|
||||||
|
|
@ -41,13 +42,17 @@ class Detect(nn.Module):
|
||||||
self.cv2 = nn.ModuleList(
|
self.cv2 = nn.ModuleList(
|
||||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||||
)
|
)
|
||||||
self.cv3 = nn.ModuleList(
|
self.cv3 = (
|
||||||
nn.Sequential(
|
nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
||||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
if self.legacy
|
||||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
else nn.ModuleList(
|
||||||
nn.Conv2d(c3, self.nc, 1),
|
nn.Sequential(
|
||||||
|
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||||
|
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||||
|
nn.Conv2d(c3, self.nc, 1),
|
||||||
|
)
|
||||||
|
for x in ch
|
||||||
)
|
)
|
||||||
for x in ch
|
|
||||||
)
|
)
|
||||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -936,6 +936,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
# Args
|
# Args
|
||||||
|
legacy = True # backward compatibility for v3/v5/v8/v9 models
|
||||||
max_channels = float("inf")
|
max_channels = float("inf")
|
||||||
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
||||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
||||||
|
|
@ -1027,8 +1028,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
}:
|
}:
|
||||||
args.insert(2, n) # number of repeats
|
args.insert(2, n) # number of repeats
|
||||||
n = 1
|
n = 1
|
||||||
if m is C3k2 and scale in "mlx": # for M/L/X sizes
|
if m is C3k2: # for M/L/X sizes
|
||||||
args[3] = True
|
legacy = False
|
||||||
|
if scale in "mlx":
|
||||||
|
args[3] = True
|
||||||
elif m is AIFI:
|
elif m is AIFI:
|
||||||
args = [ch[f], *args]
|
args = [ch[f], *args]
|
||||||
elif m in {HGStem, HGBlock}:
|
elif m in {HGStem, HGBlock}:
|
||||||
|
|
@ -1047,6 +1050,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
args.append([ch[x] for x in f])
|
args.append([ch[x] for x in f])
|
||||||
if m is Segment:
|
if m is Segment:
|
||||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||||
|
if m in {Detect, Segment, Pose, OBB}:
|
||||||
|
m.legacy = legacy
|
||||||
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
||||||
args.insert(1, [ch[x] for x in f])
|
args.insert(1, [ch[x] for x in f])
|
||||||
elif m is CBLinear:
|
elif m is CBLinear:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue