ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
246c3eca81
commit
cc1e77138c
8 changed files with 104 additions and 2 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.58"
|
||||
__version__ = "8.3.59"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
21
ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml
Normal file
21
ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
# YOLO11-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
|
||||
|
||||
# Parameters
|
||||
nc: 10 # number of classes
|
||||
scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n'
|
||||
# [depth, width, max_channels]
|
||||
n: [0.33, 0.25, 1024]
|
||||
s: [0.33, 0.50, 1024]
|
||||
m: [0.67, 0.75, 1024]
|
||||
l: [1.00, 1.00, 1024]
|
||||
x: [1.00, 1.25, 1024]
|
||||
|
||||
# YOLO11n backbone
|
||||
backbone:
|
||||
# [from, repeats, module, args]
|
||||
- [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end
|
||||
|
||||
# YOLO11n head
|
||||
head:
|
||||
- [-1, 1, Classify, [nc]] # Classify
|
||||
|
|
@ -56,6 +56,7 @@ from .block import (
|
|||
RepVGGDW,
|
||||
ResNetLayer,
|
||||
SCDown,
|
||||
TorchVision,
|
||||
)
|
||||
from .conv import (
|
||||
CBAM,
|
||||
|
|
@ -68,6 +69,7 @@ from .conv import (
|
|||
DWConvTranspose2d,
|
||||
Focus,
|
||||
GhostConv,
|
||||
Index,
|
||||
LightConv,
|
||||
RepConv,
|
||||
SpatialAttention,
|
||||
|
|
@ -156,4 +158,6 @@ __all__ = (
|
|||
"C2fCIB",
|
||||
"Attention",
|
||||
"PSA",
|
||||
"TorchVision",
|
||||
"Index",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ __all__ = (
|
|||
"Attention",
|
||||
"PSA",
|
||||
"SCDown",
|
||||
"TorchVision",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
|
|||
def forward(self, x):
|
||||
"""Applies convolution and downsampling to the input tensor in the SCDown module."""
|
||||
return self.cv2(self.cv1(x))
|
||||
|
||||
|
||||
class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights,
    and customize the model by truncating or unwrapping layers.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): If True, forward returns the output of every child module as a list.

    Args:
        c1 (int): Input channels.
        c2 (int): Output channels.
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
            layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
    """

    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
        """Load the model and weights from torchvision."""
        import torchvision  # scope for faster 'import ultralytics'

        super().__init__()
        if hasattr(torchvision.models, "get_model"):  # torchvision >= 0.14 API
            self.m = torchvision.models.get_model(model, weights=weights)
        else:  # legacy API only supports a boolean pretrained flag
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            # NOTE: '[:-0]' would yield an empty list, so guard truncate=0 to keep all layers
            layers = list(self.m.children())[:-truncate] if truncate else list(self.m.children())
            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
                layers = [*list(layers[0].children()), *layers[1:]]
            self.m = nn.Sequential(*layers)
            self.split = split
        else:
            self.split = False
            self.m.head = self.m.heads = nn.Identity()  # strip classification head(s) when keeping wrapper

    def forward(self, x):
        """Forward pass through the model, returning a list of intermediate outputs if `self.split` is True."""
        if self.split:
            y = [x]
            y.extend(m(y[-1]) for m in self.m)
        else:
            y = self.m(x)
        return y
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ __all__ = (
|
|||
"CBAM",
|
||||
"Concat",
|
||||
"RepConv",
|
||||
"Index",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -330,3 +331,20 @@ class Concat(nn.Module):
|
|||
def forward(self, x):
|
||||
"""Forward pass for the YOLOv8 mask Proto module."""
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class Index(nn.Module):
    """Select and return a single element from a list of input tensors."""

    def __init__(self, c1, c2, index=0):
        """
        Initialize the Index module.

        Args:
            c1 (int): Input channels (unused; present for parser signature compatibility).
            c2 (int): Output channels (unused; present for parser signature compatibility).
            index (int, optional): Position of the input element to return. Default is 0.
        """
        super().__init__()
        self.index = index

    def forward(self, x):
        """Return the element of `x` at the configured index; expects a list of tensors as input."""
        return x[self.index]
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ from ultralytics.nn.modules import (
|
|||
HGBlock,
|
||||
HGStem,
|
||||
ImagePoolingAttn,
|
||||
Index,
|
||||
Pose,
|
||||
RepC3,
|
||||
RepConv,
|
||||
|
|
@ -59,6 +60,7 @@ from ultralytics.nn.modules import (
|
|||
RTDETRDecoder,
|
||||
SCDown,
|
||||
Segment,
|
||||
TorchVision,
|
||||
WorldDetect,
|
||||
v10Detect,
|
||||
)
|
||||
|
|
@ -1052,7 +1054,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
m.legacy = legacy
|
||||
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
||||
args.insert(1, [ch[x] for x in f])
|
||||
elif m is CBLinear:
|
||||
elif m in {CBLinear, TorchVision, Index}:
|
||||
c2 = args[0]
|
||||
c1 = ch[f]
|
||||
args = [c1, c2, *args[1:]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue