Cleanup TorchVision related functions (#18790)
This commit is contained in:
parent
066c5443f5
commit
5306a8cc1f
4 changed files with 8 additions and 13 deletions
|
|
@ -6,18 +6,11 @@
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 10 # number of classes
|
nc: 10 # number of classes
|
||||||
scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n'
|
|
||||||
# [depth, width, max_channels]
|
|
||||||
n: [0.33, 0.25, 1024]
|
|
||||||
s: [0.33, 0.50, 1024]
|
|
||||||
m: [0.67, 0.75, 1024]
|
|
||||||
l: [1.00, 1.00, 1024]
|
|
||||||
x: [1.00, 1.25, 1024]
|
|
||||||
|
|
||||||
# ResNet18 backbone
|
# ResNet18 backbone
|
||||||
backbone:
|
backbone:
|
||||||
# [from, repeats, module, args]
|
# [from, repeats, module, args]
|
||||||
- [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end
|
- [-1, 1, TorchVision, [512, resnet18, DEFAULT, True, 2]] # truncate two layers from the end
|
||||||
|
|
||||||
# YOLO11n head
|
# YOLO11n head
|
||||||
head:
|
head:
|
||||||
|
|
|
||||||
|
|
@ -1120,8 +1120,6 @@ class TorchVision(nn.Module):
|
||||||
m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
|
m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
c1 (int): Input channels.
|
|
||||||
c2 (): Output channels.
|
|
||||||
model (str): Name of the torchvision model to load.
|
model (str): Name of the torchvision model to load.
|
||||||
weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
|
weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
|
||||||
unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
|
unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True.
|
||||||
|
|
@ -1129,7 +1127,7 @@ class TorchVision(nn.Module):
|
||||||
split (bool, optional): Returns output from intermediate child modules as list. Default is False.
|
split (bool, optional): Returns output from intermediate child modules as list. Default is False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
|
def __init__(self, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
|
||||||
"""Load the model and weights from torchvision."""
|
"""Load the model and weights from torchvision."""
|
||||||
import torchvision # scope for faster 'import ultralytics'
|
import torchvision # scope for faster 'import ultralytics'
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -336,7 +336,7 @@ class Concat(nn.Module):
|
||||||
class Index(nn.Module):
|
class Index(nn.Module):
|
||||||
"""Returns a particular index of the input."""
|
"""Returns a particular index of the input."""
|
||||||
|
|
||||||
def __init__(self, c1, c2, index=0):
|
def __init__(self, index=0):
|
||||||
"""Returns a particular index of the input."""
|
"""Returns a particular index of the input."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.index = index
|
self.index = index
|
||||||
|
|
|
||||||
|
|
@ -1060,12 +1060,16 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
m.legacy = legacy
|
m.legacy = legacy
|
||||||
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
||||||
args.insert(1, [ch[x] for x in f])
|
args.insert(1, [ch[x] for x in f])
|
||||||
elif m in frozenset({CBLinear, TorchVision, Index}):
|
elif m is CBLinear:
|
||||||
c2 = args[0]
|
c2 = args[0]
|
||||||
c1 = ch[f]
|
c1 = ch[f]
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
elif m is CBFuse:
|
elif m is CBFuse:
|
||||||
c2 = ch[f[-1]]
|
c2 = ch[f[-1]]
|
||||||
|
elif m in frozenset({TorchVision, Index}):
|
||||||
|
c2 = args[0]
|
||||||
|
c1 = ch[f]
|
||||||
|
args = [*args[1:]]
|
||||||
else:
|
else:
|
||||||
c2 = ch[f]
|
c2 = ch[f]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue