ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent 246c3eca81
commit cc1e77138c

8 changed files with 104 additions and 2 deletions
docs/en/reference/nn/modules/block.md

@@ -189,4 +189,8 @@ keywords: Ultralytics, YOLO, neural networks, block modules, DFL, Proto, HGStem,
 ## ::: ultralytics.nn.modules.block.SCDown
 
+<br><br><hr><br>
+
+## ::: ultralytics.nn.modules.block.TorchVision
+
 <br><br>
docs/en/reference/nn/modules/conv.md

@@ -63,6 +63,10 @@ keywords: Ultralytics, convolution modules, Conv, LightConv, GhostConv, YOLO, de
 <br><br><hr><br>
 
+## ::: ultralytics.nn.modules.conv.Index
+
+<br><br><hr><br>
+
 ## ::: ultralytics.nn.modules.conv.autopad
 
 <br><br>
ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.58"
+__version__ = "8.3.59"
 
 import os
ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml (new file, 21 lines)

@@ -0,0 +1,21 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# YOLO11-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
+
+# Parameters
+nc: 10 # number of classes
+scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n'
+  # [depth, width, max_channels]
+  n: [0.33, 0.25, 1024]
+  s: [0.33, 0.50, 1024]
+  m: [0.67, 0.75, 1024]
+  l: [1.00, 1.00, 1024]
+  x: [1.00, 1.25, 1024]
+
+# YOLO11n backbone
+backbone:
+  # [from, repeats, module, args]
+  - [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end
+
+# YOLO11n head
+head:
+  - [-1, 1, Classify, [nc]] # Classify
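A minimal usage sketch for this new config via the ultralytics Python API; the mnist160 toy dataset and training settings below are illustrative assumptions, not part of the commit:

```python
from ultralytics import YOLO

# Build a classifier from the new config; the TorchVision layer loads an
# ImageNet-pretrained resnet18 and drops its last two children (avgpool, fc).
model = YOLO("yolo11-cls-resnet18.yaml")

# Train briefly on a small classification dataset (illustrative settings).
model.train(data="mnist160", epochs=3, imgsz=64)
```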
ultralytics/nn/modules/__init__.py

@@ -56,6 +56,7 @@ from .block import (
     RepVGGDW,
     ResNetLayer,
     SCDown,
+    TorchVision,
 )
 from .conv import (
     CBAM,
@@ -68,6 +69,7 @@ from .conv import (
     DWConvTranspose2d,
     Focus,
     GhostConv,
+    Index,
     LightConv,
     RepConv,
     SpatialAttention,
@@ -156,4 +158,6 @@ __all__ = (
     "C2fCIB",
     "Attention",
     "PSA",
+    "TorchVision",
+    "Index",
 )
ultralytics/nn/modules/block.py

@@ -49,6 +49,7 @@ __all__ = (
     "Attention",
     "PSA",
     "SCDown",
+    "TorchVision",
 )
@@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
     def forward(self, x):
         """Applies convolution and downsampling to the input tensor in the SCDown module."""
         return self.cv2(self.cv1(x))
+
+
+class TorchVision(nn.Module):
+    """
+    TorchVision module to allow loading any torchvision model.
+
+    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights,
+    and customize the model by truncating or unwrapping layers.
+
+    Attributes:
+        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
+
+    Args:
+        c1 (int): Input channels.
+        c2 (int): Output channels.
+        model (str): Name of the torchvision model to load.
+        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
+        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
+            layers. Default is True.
+        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
+        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
+    """
+
+    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
+        """Load the model and weights from torchvision."""
+        import torchvision  # scope for faster 'import ultralytics'
+
+        super().__init__()
+        if hasattr(torchvision.models, "get_model"):
+            self.m = torchvision.models.get_model(model, weights=weights)
+        else:
+            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
+        if unwrap:
+            layers = list(self.m.children())[:-truncate]
+            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
+                layers = [*list(layers[0].children()), *layers[1:]]
+            self.m = nn.Sequential(*layers)
+            self.split = split
+        else:
+            self.split = False
+            self.m.head = self.m.heads = nn.Identity()
+
+    def forward(self, x):
+        """Forward pass through the model."""
+        if self.split:
+            y = [x]
+            y.extend(m(y[-1]) for m in self.m)
+        else:
+            y = self.m(x)
+        return y
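To make the unwrap/truncate behavior concrete, here is a small standalone sketch; the output shape is an assumption for resnet18 at 224x224 input, and note that `c1`/`c2` are not used inside `__init__` itself — `parse_model` consumes them for channel bookkeeping:

```python
import torch

from ultralytics.nn.modules import TorchVision

# resnet18 with its last two children (avgpool, fc) truncated, leaving a
# convolutional feature extractor with 512 output channels.
m = TorchVision(3, 512, "resnet18", weights="DEFAULT", unwrap=True, truncate=2)

x = torch.zeros(1, 3, 224, 224)
y = m(x)
print(y.shape)  # torch.Size([1, 512, 7, 7])
```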
ultralytics/nn/modules/conv.py

@@ -21,6 +21,7 @@ __all__ = (
     "CBAM",
     "Concat",
     "RepConv",
+    "Index",
 )
@@ -330,3 +331,20 @@ class Concat(nn.Module):
     def forward(self, x):
         """Forward pass for the YOLOv8 mask Proto module."""
         return torch.cat(x, self.d)
+
+
+class Index(nn.Module):
+    """Returns a particular index of the input."""
+
+    def __init__(self, c1, c2, index=0):
+        """Returns a particular index of the input."""
+        super().__init__()
+        self.index = index
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        Expects a list of tensors as input.
+        """
+        return x[self.index]
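A quick sketch of Index in isolation (the shapes are arbitrary): it selects one tensor from a list, e.g. to pick a single feature map out of the intermediate outputs that TorchVision produces when `split=True`. As with TorchVision, `c1`/`c2` are bookkeeping arguments for `parse_model`:

```python
import torch

from ultralytics.nn.modules import Index

idx = Index(64, 64, index=1)  # select the second tensor from a list
outs = [torch.zeros(1, 32, 8, 8), torch.zeros(1, 64, 4, 4)]
print(idx(outs).shape)  # torch.Size([1, 64, 4, 4])
```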
ultralytics/nn/tasks.py

@@ -50,6 +50,7 @@ from ultralytics.nn.modules import (
     HGBlock,
     HGStem,
     ImagePoolingAttn,
+    Index,
     Pose,
     RepC3,
     RepConv,
@@ -59,6 +60,7 @@ from ultralytics.nn.modules import (
     RTDETRDecoder,
     SCDown,
     Segment,
+    TorchVision,
     WorldDetect,
     v10Detect,
 )
@@ -1052,7 +1054,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             m.legacy = legacy
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
-        elif m is CBLinear:
+        elif m in {CBLinear, TorchVision, Index}:
             c2 = args[0]
             c1 = ch[f]
             args = [c1, c2, *args[1:]]
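With this change, parse_model treats TorchVision and Index like CBLinear: the first YAML arg is read as the output channel count `c2`, the incoming channel count `ch[f]` is prepended as `c1`, and the remaining args pass through. A hedged sketch of what the YAML backbone line above becomes at construction time:

```python
from ultralytics.nn.modules import TorchVision

# YAML: - [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]]
# parse_model reads c2 = args[0] = 512, prepends c1 = ch[f] = 3 (RGB input),
# so the module is constructed as:
m = TorchVision(3, 512, "resnet18", "DEFAULT", True, 2)
```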