Use frozenset() (#18785)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-21 01:38:00 +01:00 committed by GitHub
parent fb3e5adfd7
commit 9341c1df76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 143 additions and 131 deletions

View file

@ -47,8 +47,8 @@ SOLUTION_MAP = {
} }
# Define valid tasks and modes # Define valid tasks and modes
MODES = {"train", "val", "predict", "export", "track", "benchmark"} MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
TASKS = {"detect", "segment", "classify", "pose", "obb"} TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
TASK2DATA = { TASK2DATA = {
"detect": "coco8.yaml", "detect": "coco8.yaml",
"segment": "coco8-seg.yaml", "segment": "coco8-seg.yaml",
@ -70,7 +70,7 @@ TASK2METRIC = {
"pose": "metrics/mAP50-95(P)", "pose": "metrics/mAP50-95(P)",
"obb": "metrics/mAP50-95(B)", "obb": "metrics/mAP50-95(B)",
} }
MODELS = {TASK2MODEL[task] for task in TASKS} MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
SOLUTIONS_HELP_MSG = f""" SOLUTIONS_HELP_MSG = f"""
@ -144,7 +144,8 @@ CLI_HELP_MSG = f"""
""" """
# Define keys for arg type checks # Define keys for arg type checks
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0 CFG_FLOAT_KEYS = frozenset(
{ # integer or float arguments, i.e. x=2 and x=2.0
"warmup_epochs", "warmup_epochs",
"box", "box",
"cls", "cls",
@ -154,8 +155,10 @@ CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
"time", "time",
"workspace", "workspace",
"batch", "batch",
} }
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0 )
CFG_FRACTION_KEYS = frozenset(
{ # fractional float arguments with 0.0<=values<=1.0
"dropout", "dropout",
"lr0", "lr0",
"lrf", "lrf",
@ -178,8 +181,10 @@ CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
"conf", "conf",
"iou", "iou",
"fraction", "fraction",
} }
CFG_INT_KEYS = { # integer-only arguments )
CFG_INT_KEYS = frozenset(
{ # integer-only arguments
"epochs", "epochs",
"patience", "patience",
"workers", "workers",
@ -191,8 +196,10 @@ CFG_INT_KEYS = { # integer-only arguments
"line_width", "line_width",
"nbs", "nbs",
"save_period", "save_period",
} }
CFG_BOOL_KEYS = { # boolean-only arguments )
CFG_BOOL_KEYS = frozenset(
{ # boolean-only arguments
"save", "save",
"exist_ok", "exist_ok",
"verbose", "verbose",
@ -227,7 +234,8 @@ CFG_BOOL_KEYS = { # boolean-only arguments
"nms", "nms",
"profile", "profile",
"multi_scale", "multi_scale",
} }
)
def cfg2dict(cfg): def cfg2dict(cfg):
@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
- Prints detailed error messages for each mismatched key to help users correct their configurations. - Prints detailed error messages for each mismatched key to help users correct their configurations.
""" """
custom = _handle_deprecation(custom) custom = _handle_deprecation(custom)
base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
if mismatched := [k for k in custom_keys if k not in base_keys]: if mismatched := [k for k in custom_keys if k not in base_keys]:
from difflib import get_close_matches from difflib import get_close_matches

View file

@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}") LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch] ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args base_modules = frozenset(
m = ( {
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
Classify, Classify,
Conv, Conv,
ConvTranspose, ConvTranspose,
@ -1001,18 +989,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
PSA, PSA,
SCDown, SCDown,
C2fCIB, C2fCIB,
}: }
c1, c2 = ch[f], args[0] )
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output) repeat_modules = frozenset( # modules with 'repeat' arguments
c2 = make_divisible(min(c2, max_channels) * width, 8) {
if m is C2fAttn:
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
args[2] = int(
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
) # num heads
args = [c1, c2, *args[1:]]
if m in {
BottleneckCSP, BottleneckCSP,
C1, C1,
C2, C2,
@ -1027,7 +1007,31 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C2fPSA, C2fPSA,
C2fCIB, C2fCIB,
C2PSA, C2PSA,
}: }
)
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in base_modules:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn: # set 1) embed channels and 2) num heads
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
args = [c1, c2, *args[1:]]
if m in repeat_modules:
args.insert(2, n) # number of repeats args.insert(2, n) # number of repeats
n = 1 n = 1
if m is C3k2: # for M/L/X sizes if m is C3k2: # for M/L/X sizes
@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args[3] = True args[3] = True
elif m is AIFI: elif m is AIFI:
args = [ch[f], *args] args = [ch[f], *args]
elif m in {HGStem, HGBlock}: elif m in frozenset({HGStem, HGBlock}):
c1, cm, c2 = ch[f], args[0], args[1] c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]] args = [c1, cm, c2, *args[2:]]
if m is HGBlock: if m is HGBlock:
@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]] args = [ch[f]]
elif m is Concat: elif m is Concat:
c2 = sum(ch[x] for x in f) c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}: elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
args.append([ch[x] for x in f]) args.append([ch[x] for x in f])
if m is Segment: if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8) args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m.legacy = legacy m.legacy = legacy
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1 elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f]) args.insert(1, [ch[x] for x in f])
elif m in {CBLinear, TorchVision, Index}: elif m in frozenset({CBLinear, TorchVision, Index}):
c2 = args[0] c2 = args[0]
c1 = ch[f] c1 = ch[f]
args = [c1, c2, *args[1:]] args = [c1, c2, *args[1:]]

View file

@ -55,8 +55,8 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
unmatched_a = list(np.arange(cost_matrix.shape[0])) unmatched_a = list(np.arange(cost_matrix.shape[0]))
unmatched_b = list(np.arange(cost_matrix.shape[1])) unmatched_b = list(np.arange(cost_matrix.shape[1]))
else: else:
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0])) unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1])) unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
return matches, unmatched_a, unmatched_b return matches, unmatched_a, unmatched_b

View file

@ -1227,7 +1227,7 @@ class SettingsManager(JSONDict):
def _validate_settings(self): def _validate_settings(self):
"""Validate the current settings and reset if necessary.""" """Validate the current settings and reset if necessary."""
correct_keys = set(self.keys()) == set(self.defaults.keys()) correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items()) correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
correct_version = self.get("settings_version", "") == self.version correct_version = self.get("settings_version", "") == self.version

View file

@ -407,7 +407,7 @@ class Instances:
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis) cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
seg_len = [b.segments.shape[1] for b in instances_list] seg_len = [b.segments.shape[1] for b in instances_list]
if len(set(seg_len)) > 1: # resample segments if there's different length if len(frozenset(seg_len)) > 1: # resample segments if there's different length
max_len = max(seg_len) max_len = max(seg_len)
cat_segments = np.concatenate( cat_segments = np.concatenate(
[ [