Backward compatibility support for legacy models (#17010)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
8d7d1fe390
commit
899fb0495e
2 changed files with 18 additions and 8 deletions
|
|
@ -28,6 +28,7 @@ class Detect(nn.Module):
|
|||
shape = None
|
||||
anchors = torch.empty(0) # init
|
||||
strides = torch.empty(0) # init
|
||||
legacy = False # backward compatibility for v3/v5/v8/v9 models
|
||||
|
||||
def __init__(self, nc=80, ch=()):
|
||||
"""Initializes the YOLO detection layer with specified number of classes and channels."""
|
||||
|
|
@ -41,13 +42,17 @@ class Detect(nn.Module):
|
|||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
|
||||
)
|
||||
self.cv3 = nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, self.nc, 1),
|
||||
self.cv3 = (
|
||||
nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
||||
if self.legacy
|
||||
else nn.ModuleList(
|
||||
nn.Sequential(
|
||||
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
|
||||
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
|
||||
nn.Conv2d(c3, self.nc, 1),
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
for x in ch
|
||||
)
|
||||
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue