ultralytics 8.3.59 Add ability to load any torchvision model as module (#18564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
246c3eca81
commit
cc1e77138c
8 changed files with 104 additions and 2 deletions
|
|
@ -21,6 +21,7 @@ __all__ = (
|
|||
"CBAM",
|
||||
"Concat",
|
||||
"RepConv",
|
||||
"Index",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -330,3 +331,20 @@ class Concat(nn.Module):
|
|||
def forward(self, x):
|
||||
"""Forward pass for the YOLOv8 mask Proto module."""
|
||||
return torch.cat(x, self.d)
|
||||
|
||||
|
||||
class Index(nn.Module):
|
||||
"""Returns a particular index of the input."""
|
||||
|
||||
def __init__(self, c1, c2, index=0):
|
||||
"""Returns a particular index of the input."""
|
||||
super().__init__()
|
||||
self.index = index
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
Expects a list of tensors as input.
|
||||
"""
|
||||
return x[self.index]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue