Use frozenset() (#18785)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-21 01:38:00 +01:00 committed by GitHub
parent fb3e5adfd7
commit 9341c1df76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 143 additions and 131 deletions

View file

@ -47,8 +47,8 @@ SOLUTION_MAP = {
}
# Define valid tasks and modes
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
TASKS = {"detect", "segment", "classify", "pose", "obb"}
MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
TASK2DATA = {
"detect": "coco8.yaml",
"segment": "coco8-seg.yaml",
@ -70,7 +70,7 @@ TASK2METRIC = {
"pose": "metrics/mAP50-95(P)",
"obb": "metrics/mAP50-95(B)",
}
MODELS = {TASK2MODEL[task] for task in TASKS}
MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
SOLUTIONS_HELP_MSG = f"""
@ -144,7 +144,8 @@ CLI_HELP_MSG = f"""
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
CFG_FLOAT_KEYS = frozenset(
{ # integer or float arguments, i.e. x=2 and x=2.0
"warmup_epochs",
"box",
"cls",
@ -155,7 +156,9 @@ CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
"workspace",
"batch",
}
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
)
CFG_FRACTION_KEYS = frozenset(
{ # fractional float arguments with 0.0<=values<=1.0
"dropout",
"lr0",
"lrf",
@ -179,7 +182,9 @@ CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
"iou",
"fraction",
}
CFG_INT_KEYS = { # integer-only arguments
)
CFG_INT_KEYS = frozenset(
{ # integer-only arguments
"epochs",
"patience",
"workers",
@ -192,7 +197,9 @@ CFG_INT_KEYS = { # integer-only arguments
"nbs",
"save_period",
}
CFG_BOOL_KEYS = { # boolean-only arguments
)
CFG_BOOL_KEYS = frozenset(
{ # boolean-only arguments
"save",
"exist_ok",
"verbose",
@ -228,6 +235,7 @@ CFG_BOOL_KEYS = { # boolean-only arguments
"profile",
"multi_scale",
}
)
def cfg2dict(cfg):
@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
- Prints detailed error messages for each mismatched key to help users correct their configurations.
"""
custom = _handle_deprecation(custom)
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
if mismatched := [k for k in custom_keys if k not in base_keys]:
from difflib import get_close_matches

View file

@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
ch = [ch]
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in {
base_modules = frozenset(
{
Classify,
Conv,
ConvTranspose,
@ -1001,18 +989,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
PSA,
SCDown,
C2fCIB,
}:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn:
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
args[2] = int(
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
) # num heads
args = [c1, c2, *args[1:]]
if m in {
}
)
repeat_modules = frozenset( # modules with 'repeat' arguments
{
BottleneckCSP,
C1,
C2,
@ -1027,7 +1007,31 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C2fPSA,
C2fCIB,
C2PSA,
}:
}
)
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = (
getattr(torch.nn, m[3:])
if "nn." in m
else getattr(__import__("torchvision").ops, m[16:])
if "torchvision.ops." in m
else globals()[m]
) # get module
for j, a in enumerate(args):
if isinstance(a, str):
with contextlib.suppress(ValueError):
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
if m in base_modules:
c1, c2 = ch[f], args[0]
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(min(c2, max_channels) * width, 8)
if m is C2fAttn: # set 1) embed channels and 2) num heads
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
args = [c1, c2, *args[1:]]
if m in repeat_modules:
args.insert(2, n) # number of repeats
n = 1
if m is C3k2: # for M/L/X sizes
@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args[3] = True
elif m is AIFI:
args = [ch[f], *args]
elif m in {HGStem, HGBlock}:
elif m in frozenset({HGStem, HGBlock}):
c1, cm, c2 = ch[f], args[0], args[1]
args = [c1, cm, c2, *args[2:]]
if m is HGBlock:
@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
args.append([ch[x] for x in f])
if m is Segment:
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
m.legacy = legacy
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
args.insert(1, [ch[x] for x in f])
elif m in {CBLinear, TorchVision, Index}:
elif m in frozenset({CBLinear, TorchVision, Index}):
c2 = args[0]
c1 = ch[f]
args = [c1, c2, *args[1:]]

View file

@ -55,8 +55,8 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
unmatched_a = list(np.arange(cost_matrix.shape[0]))
unmatched_b = list(np.arange(cost_matrix.shape[1]))
else:
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
return matches, unmatched_a, unmatched_b

View file

@ -1227,7 +1227,7 @@ class SettingsManager(JSONDict):
def _validate_settings(self):
"""Validate the current settings and reset if necessary."""
correct_keys = set(self.keys()) == set(self.defaults.keys())
correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
correct_version = self.get("settings_version", "") == self.version

View file

@ -407,7 +407,7 @@ class Instances:
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
seg_len = [b.segments.shape[1] for b in instances_list]
if len(set(seg_len)) > 1: # resample segments if there's different length
if len(frozenset(seg_len)) > 1: # resample segments if there's different length
max_len = max(seg_len)
cat_segments = np.concatenate(
[