Use frozenset() (#18785)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-21 01:38:00 +01:00 committed by GitHub
parent fb3e5adfd7
commit 9341c1df76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 143 additions and 131 deletions

View file

@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
base_modules = frozenset(
{
Classify,
Conv,
ConvTranspose,
@ -1001,33 +989,49 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
PSA,
SCDown,
C2fCIB,
}:
}
)
repeat_modules = frozenset( # modules with 'repeat' arguments
{
BottleneckCSP,
C1,
C2,
C2f,
C3k2,
C2fAttn,
C3,
C3TR,
C3Ghost,
C3x,
RepC3,
C2fPSA,
C2fCIB,
C2PSA,
}
)
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in base_modules:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn:
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
args[2] = int(
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
) # num heads
if m is C2fAttn: # set 1) embed channels and 2) num heads
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
args = [c1, c2, *args[1:]]
if m in {
BottleneckCSP,
C1,
C2,
C2f,
C3k2,
C2fAttn,
C3,
C3TR,
C3Ghost,
C3x,
RepC3,
C2fPSA,
C2fCIB,
C2PSA,
}:
if m in repeat_modules:
args.insert(2, n) # number of repeats
n = 1
if m is C3k2: # for M/L/X sizes
@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args[3] = True
elif m is AIFI:
args = [ch[f], *args]
elif m in {HGStem, HGBlock}:
elif m in frozenset({HGStem, HGBlock}):
c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]]
if m is HGBlock:
@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m.legacy = legacy
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m in {CBLinear, TorchVision, Index}:
elif m in frozenset({CBLinear, TorchVision, Index}):
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]