ultralytics 8.3.64 new torchvision.ops access in model YAMLs (#18680)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2025-01-20 06:44:56 +08:00 committed by GitHub
parent a8e2464a9c
commit 673b43ce17
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 11 additions and 5 deletions

View file

@ -1139,10 +1139,10 @@ class TorchVision(nn.Module):
else:
self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
if unwrap:
layers = list(self.m.children())[:-truncate]
layers = list(self.m.children())
if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin
layers = [*list(layers[0].children()), *layers[1:]]
self.m = nn.Sequential(*layers)
self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))
self.split = split
else:
self.split = False

View file

@ -955,7 +955,13 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):