ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2025-01-09 20:57:46 +08:00 committed by GitHub
parent 246c3eca81
commit cc1e77138c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 104 additions and 2 deletions

View file

@ -49,6 +49,7 @@ __all__ = (
"Attention",
"PSA",
"SCDown",
"TorchVision",
)
@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
def forward(self, x):
"""Applies convolution and downsampling to the input tensor in the SCDown module."""
return self.cv2(self.cv1(x))
class TorchVision(nn.Module):
"""
TorchVision module to allow loading any torchvision model.
This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers.
Attributes:
m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
Args:
c1 (int): Input channels.
c2 (): Output channels.
model (str): Name of the torchvision model to load.
weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2.
split (bool, optional): Returns output from intermediate child modules as list. Default is False.
"""
def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
"""Load the model and weights from torchvision."""
import torchvision
super().__init__()
if hasattr(torchvision.models, "get_model"):
self.m = torchvision.models.get_model(model, weights=weights)
else:
self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
if unwrap:
layers = list(self.m.children())[:-truncate]
if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin
layers = [*list(layers[0].children()), *layers[1:]]
self.m = nn.Sequential(*layers)
self.split = split
else:
self.split = False
self.m.head = self.m.heads = nn.Identity()
def forward(self, x):
"""Forward pass through the model."""
if self.split:
y = [x]
y.extend(m(y[-1]) for m in self.m)
else:
y = self.m(x)
return y