ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
246c3eca81
commit
cc1e77138c
8 changed files with 104 additions and 2 deletions
|
|
@ -49,6 +49,7 @@ __all__ = (
|
|||
"Attention",
|
||||
"PSA",
|
||||
"SCDown",
|
||||
"TorchVision",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1107,3 +1108,51 @@ class SCDown(nn.Module):
|
|||
def forward(self, x):
    """Applies the SCDown module's `cv1` layer and then its `cv2` layer to the input tensor."""
    hidden = self.cv1(x)
    return self.cv2(hidden)
|
||||
|
||||
|
||||
class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and
    customize the model by truncating or unwrapping layers.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): If True, `forward` returns the input followed by the output of every child module of `m`.

    Args:
        c1 (int): Input channels. Not read by this module.
        c2 (int): Output channels. Not read by this module.
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
            layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. A value of 0 keeps
            every layer. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Only honoured when `unwrap`
            is True. Default is False.
    """

    def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
        """Load the model and weights from torchvision."""
        import torchvision  # imported lazily so the dependency is only needed when this module is used

        super().__init__()
        if hasattr(torchvision.models, "get_model"):
            self.m = torchvision.models.get_model(model, weights=weights)
        else:
            # Fallback for torchvision builds without `get_model`: look the builder up by name
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            self.m = nn.Sequential(*self._unwrap_layers(self.m, truncate))
            self.split = split
        else:
            # Splitting iterates over the children of a Sequential, so it is disabled for the wrapped model
            self.split = False
            self.m.head = self.m.heads = nn.Identity()

    @staticmethod
    def _unwrap_layers(model, truncate):
        """Return the children of `model` minus the last `truncate`, flattening a leading nn.Sequential."""
        children = list(model.children())
        # `children[:-0]` would be empty, so a truncate of 0 must bypass the slice to keep every layer
        layers = children[:-truncate] if truncate else children
        # Second-level unwrap for some models like EfficientNet, Swin; guarded so an empty list cannot raise
        if layers and isinstance(layers[0], nn.Sequential):
            layers = [*layers[0].children(), *layers[1:]]
        return layers

    def forward(self, x):
        """
        Forward pass through the model.

        Returns the output of `self.m(x)`, or, when `self.split` is True, a list holding `x` followed by the output
        of each child module applied in sequence.
        """
        if self.split:
            y = [x]
            for layer in self.m:
                # Each child consumes the most recent output
                y.append(layer(y[-1]))
            return y
        return self.m(x)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue