ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
246c3eca81
commit
cc1e77138c
8 changed files with 104 additions and 2 deletions
|
|
@ -56,6 +56,7 @@ from .block import (
|
|||
RepVGGDW,
|
||||
ResNetLayer,
|
||||
SCDown,
|
||||
TorchVision,
|
||||
)
|
||||
from .conv import (
|
||||
CBAM,
|
||||
|
|
@ -68,6 +69,7 @@ from .conv import (
|
|||
DWConvTranspose2d,
|
||||
Focus,
|
||||
GhostConv,
|
||||
Index,
|
||||
LightConv,
|
||||
RepConv,
|
||||
SpatialAttention,
|
||||
|
|
@ -156,4 +158,6 @@ __all__ = (
|
|||
"C2fCIB",
|
||||
"Attention",
|
||||
"PSA",
|
||||
"TorchVision",
|
||||
"Index",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ __all__ = (
|
|||
"Attention",
|
||||
"PSA",
|
||||
"SCDown",
|
||||
"TorchVision",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
|
|||
def forward(self, x):
    """
    Run the SCDown forward pass.

    The input is passed through `self.cv1` and the result through `self.cv2`
    (presumably the module's convolution and downsampling layers; they are
    defined outside this view).
    """
    hidden = self.cv1(x)
    return self.cv2(hidden)
|
||||
|
||||
|
||||
class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and
    customize the model by truncating or unwrapping layers.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): Whether `forward` returns the list of intermediate outputs. Always False when `unwrap` is False.

    Args:
        c1 (int): Input channels. Not read by this class.
        c2 (int): Output channels. Not read by this class.
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT". On torchvision versions without
            `torchvision.models.get_model` this is reduced to `pretrained=bool(weights)`.
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
            layers. If False, the model is kept whole and its `head` and `heads` attributes are set to
            `nn.Identity()`. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. 0 keeps every layer.
            Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Only honoured when `unwrap`
            is True. Default is False.
    """

    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
        """Load the model and weights from torchvision."""
        import torchvision  # local import: only needed when this module is actually instantiated

        super().__init__()
        if hasattr(torchvision.models, "get_model"):
            self.m = torchvision.models.get_model(model, weights=weights)
        else:
            # Fallback for torchvision builds without `get_model`: look the constructor up by name.
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            self.m = self._unwrap_layers(self.m, truncate)
            self.split = split
        else:
            self.split = False
            self.m.head = self.m.heads = nn.Identity()

    @staticmethod
    def _unwrap_layers(model, truncate):
        """
        Return an `nn.Sequential` of `model`'s children without the last `truncate` of them.

        Args:
            model (nn.Module): Module whose direct children are collected.
            truncate (int): Number of trailing children to drop; 0 keeps all of them.

        Returns:
            (nn.Sequential): The remaining children, with a leading `nn.Sequential` child expanded in place.

        Raises:
            IndexError: If no children remain after truncation.
        """
        layers = list(model.children())
        if truncate:
            # `layers[:-0]` would be an empty list, so only slice for a non-zero count.
            layers = layers[:-truncate]
        if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
            layers = [*layers[0].children(), *layers[1:]]
        return nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass through the model.

        Returns:
            The output of `self.m(x)`, or, when `self.split` is True, a list holding `x` followed by the output of
            each child of `self.m` applied in sequence.
        """
        if not self.split:
            return self.m(x)
        outputs = [x]
        for layer in self.m:
            # Each child consumes the previous child's output.
            outputs.append(layer(outputs[-1]))
        return outputs
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ __all__ = (
|
|||
"CBAM",
|
||||
"Concat",
|
||||
"RepConv",
|
||||
"Index",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -330,3 +331,20 @@ class Concat(nn.Module):
|
|||
def forward(self, x):
    """
    Concatenate the input tensors along dimension `self.d`.

    Args:
        x: Sequence of tensors, passed unchanged to `torch.cat`.

    Returns:
        The result of `torch.cat(x, self.d)`.
    """
    return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class Index(nn.Module):
    """Module that selects one element of its input by a fixed index."""

    def __init__(self, c1, c2, index=0):
        """
        Store the index to select.

        Args:
            c1: Not read by this class.
            c2: Not read by this class.
            index (int, optional): Position used to subscript the input in `forward`. Default is 0.
        """
        super().__init__()
        self.index = index

    def forward(self, x):
        """
        Return `x[self.index]`.

        Expects a list of tensors as input.
        """
        selected = x[self.index]
        return selected
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue