ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2025-01-09 20:57:46 +08:00 committed by GitHub
parent 246c3eca81
commit cc1e77138c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 104 additions and 2 deletions

View file

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.58"
__version__ = "8.3.59"
import os

View file

@@ -0,0 +1,21 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11-cls image classification model. For Usage examples see https://docs.ultralytics.com/tasks/classify
# Parameters
nc: 10 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n'
# [depth, width, max_channels]
n: [0.33, 0.25, 1024]
s: [0.33, 0.50, 1024]
m: [0.67, 0.75, 1024]
l: [1.00, 1.00, 1024]
x: [1.00, 1.25, 1024]
# YOLO11n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end
# YOLO11n head
head:
- [-1, 1, Classify, [nc]] # Classify

View file

@@ -56,6 +56,7 @@ from .block import (
RepVGGDW,
ResNetLayer,
SCDown,
TorchVision,
)
from .conv import (
CBAM,
@@ -68,6 +69,7 @@ from .conv import (
DWConvTranspose2d,
Focus,
GhostConv,
Index,
LightConv,
RepConv,
SpatialAttention,
@@ -156,4 +158,6 @@ __all__ = (
"C2fCIB",
"Attention",
"PSA",
"TorchVision",
"Index",
)

View file

@ -49,6 +49,7 @@ __all__ = (
"Attention",
"PSA",
"SCDown",
"TorchVision",
)
@@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
def forward(self, x):
"""Applies convolution and downsampling to the input tensor in the SCDown module."""
return self.cv2(self.cv1(x))
class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights,
    and customize the model by truncating or unwrapping layers.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): Whether forward returns the outputs of each child module as a list.

    Args:
        c1 (int): Input channels (unused here; consumed by the model parser).
        c2 (int): Output channels (unused here; consumed by the model parser).
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
            layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
    """

    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
        """Load the model and weights from torchvision."""
        import torchvision  # scoped import to keep base package import fast

        super().__init__()
        if hasattr(torchvision.models, "get_model"):
            self.m = torchvision.models.get_model(model, weights=weights)
        else:
            # Older torchvision versions lack get_model(); fall back to the legacy constructor API
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            # `or None` so truncate=0 keeps ALL layers ([:-0] would otherwise yield an empty list)
            layers = list(self.m.children())[: -truncate or None]
            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
                layers = [*list(layers[0].children()), *layers[1:]]
            self.m = nn.Sequential(*layers)
            self.split = split
        else:
            self.split = False
            self.m.head = self.m.heads = nn.Identity()

    def forward(self, x):
        """Forward pass through the model; returns a list of intermediate outputs when `self.split` is True."""
        if self.split:
            y = [x]
            y.extend(m(y[-1]) for m in self.m)
        else:
            y = self.m(x)
        return y

View file

@@ -21,6 +21,7 @@ __all__ = (
"CBAM",
"Concat",
"RepConv",
"Index",
)
@@ -330,3 +331,20 @@ class Concat(nn.Module):
def forward(self, x):
"""Forward pass for the YOLOv8 mask Proto module."""
return torch.cat(x, self.d)
class Index(nn.Module):
    """Returns a particular index of the input."""

    def __init__(self, c1, c2, index=0):
        """
        Initialize the Index module.

        Args:
            c1 (int): Input channels (unused here; consumed by the model parser).
            c2 (int): Output channels (unused here; consumed by the model parser).
            index (int, optional): Index of the input list to return in forward(). Default is 0.
        """
        super().__init__()
        self.index = index

    def forward(self, x):
        """
        Forward pass.

        Expects a list of tensors as input; returns the element at `self.index`.
        """
        return x[self.index]

View file

@@ -50,6 +50,7 @@ from ultralytics.nn.modules import (
HGBlock,
HGStem,
ImagePoolingAttn,
Index,
Pose,
RepC3,
RepConv,
@@ -59,6 +60,7 @@ from ultralytics.nn.modules import (
RTDETRDecoder,
SCDown,
Segment,
TorchVision,
WorldDetect,
v10Detect,
)
@@ -1052,7 +1054,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
m.legacy = legacy
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m is CBLinear:
elif m in {CBLinear, TorchVision, Index}:
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]