Use frozenset() (#18785)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
fb3e5adfd7
commit
9341c1df76
5 changed files with 143 additions and 131 deletions
|
|
@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
||||
ch = [ch]
|
||||
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
||||
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
||||
m = (
|
||||
getattr(torch.nn, m[3:])
|
||||
if "nn." in m
|
||||
else getattr(__import__("torchvision").ops, m[16:])
|
||||
if "torchvision.ops." in m
|
||||
else globals()[m]
|
||||
) # get module
|
||||
for j, a in enumerate(args):
|
||||
if isinstance(a, str):
|
||||
with contextlib.suppress(ValueError):
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in {
|
||||
base_modules = frozenset(
|
||||
{
|
||||
Classify,
|
||||
Conv,
|
||||
ConvTranspose,
|
||||
|
|
@ -1001,33 +989,49 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
PSA,
|
||||
SCDown,
|
||||
C2fCIB,
|
||||
}:
|
||||
}
|
||||
)
|
||||
repeat_modules = frozenset( # modules with 'repeat' arguments
|
||||
{
|
||||
BottleneckCSP,
|
||||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C3k2,
|
||||
C2fAttn,
|
||||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
C3x,
|
||||
RepC3,
|
||||
C2fPSA,
|
||||
C2fCIB,
|
||||
C2PSA,
|
||||
}
|
||||
)
|
||||
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
||||
m = (
|
||||
getattr(torch.nn, m[3:])
|
||||
if "nn." in m
|
||||
else getattr(__import__("torchvision").ops, m[16:])
|
||||
if "torchvision.ops." in m
|
||||
else globals()[m]
|
||||
) # get module
|
||||
for j, a in enumerate(args):
|
||||
if isinstance(a, str):
|
||||
with contextlib.suppress(ValueError):
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in base_modules:
|
||||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||
if m is C2fAttn:
|
||||
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
|
||||
args[2] = int(
|
||||
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
|
||||
) # num heads
|
||||
if m is C2fAttn: # set 1) embed channels and 2) num heads
|
||||
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
|
||||
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
|
||||
|
||||
args = [c1, c2, *args[1:]]
|
||||
if m in {
|
||||
BottleneckCSP,
|
||||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C3k2,
|
||||
C2fAttn,
|
||||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
C3x,
|
||||
RepC3,
|
||||
C2fPSA,
|
||||
C2fCIB,
|
||||
C2PSA,
|
||||
}:
|
||||
if m in repeat_modules:
|
||||
args.insert(2, n) # number of repeats
|
||||
n = 1
|
||||
if m is C3k2: # for M/L/X sizes
|
||||
|
|
@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args[3] = True
|
||||
elif m is AIFI:
|
||||
args = [ch[f], *args]
|
||||
elif m in {HGStem, HGBlock}:
|
||||
elif m in frozenset({HGStem, HGBlock}):
|
||||
c1, cm, c2 = ch[f], args[0], args[1]
|
||||
args = [c1, cm, c2, *args[2:]]
|
||||
if m is HGBlock:
|
||||
|
|
@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
|
||||
elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
|
|
@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
m.legacy = legacy
|
||||
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
||||
args.insert(1, [ch[x] for x in f])
|
||||
elif m in {CBLinear, TorchVision, Index}:
|
||||
elif m in frozenset({CBLinear, TorchVision, Index}):
|
||||
c2 = args[0]
|
||||
c1 = ch[f]
|
||||
args = [c1, c2, *args[1:]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue