Cleanup TorchVision related functions (#18790)

This commit is contained in:
Mohammed Yasin 2025-01-21 18:41:05 +08:00 committed by GitHub
parent 066c5443f5
commit 5306a8cc1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 8 additions and 13 deletions

View file

@ -1120,8 +1120,6 @@ class TorchVision(nn.Module):
m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
Args:
c1 (int): Input channels.
c2 (): Output channels.
model (str): Name of the torchvision model to load.
weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
@ -1129,7 +1127,7 @@ class TorchVision(nn.Module):
split (bool, optional): Returns output from intermediate child modules as list. Default is False.
"""
def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
def __init__(self, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
"""Load the model and weights from torchvision."""
import torchvision # scope for faster 'import ultralytics'

View file

@ -336,7 +336,7 @@ class Concat(nn.Module):
class Index(nn.Module):
"""Returns a particular index of the input."""
def __init__(self, c1, c2, index=0):
def __init__(self, index=0):
"""Returns a particular index of the input."""
super().__init__()
self.index = index

View file

@ -1060,12 +1060,16 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m.legacy = legacy
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m in frozenset({CBLinear, TorchVision, Index}):
elif m is CBLinear:
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]
elif m is CBFuse:
c2 = ch[f[-1]]
elif m in frozenset({TorchVision, Index}):
c2 = args[0]
c1 = ch[f]
args = [*args[1:]]
else:
c2 = ch[f]