Use frozenset() (#18785)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Glenn Jocher, 2025-01-21 01:38:00 +01:00, committed by GitHub
commit 9341c1df76 (parent fb3e5adfd7), signed with GPG key ID B5690EEEBB952194
5 changed files with 143 additions and 131 deletions
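A note on the technique (not part of the commit itself): frozenset gives the same O(1) membership testing as set, but the resulting constants are immutable and hashable, so shared module-level lookup tables cannot be mutated accidentally. A minimal sketch:

# Minimal sketch, not taken from the diff below
MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
print("train" in MODES)               # True, constant-time lookup
print(isinstance(hash(MODES), int))   # True, frozensets are hashable (usable as dict keys)
try:
    MODES.add("deploy")               # frozenset has no add(); the constant cannot be mutated
except AttributeError as err:
    print(err)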

ultralytics/cfg/__init__.py

@@ -47,8 +47,8 @@ SOLUTION_MAP = {
 }
 # Define valid tasks and modes
-MODES = {"train", "val", "predict", "export", "track", "benchmark"}
-TASKS = {"detect", "segment", "classify", "pose", "obb"}
+MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
+TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
 TASK2DATA = {
     "detect": "coco8.yaml",
     "segment": "coco8-seg.yaml",

@@ -70,7 +70,7 @@ TASK2METRIC = {
     "pose": "metrics/mAP50-95(P)",
     "obb": "metrics/mAP50-95(B)",
 }
-MODELS = {TASK2MODEL[task] for task in TASKS}
+MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
 ARGV = sys.argv or ["", ""]  # sometimes sys.argv = []
 SOLUTIONS_HELP_MSG = f"""
@@ -144,90 +144,98 @@ CLI_HELP_MSG = f"""
     """
 # Define keys for arg type checks
-CFG_FLOAT_KEYS = {  # integer or float arguments, i.e. x=2 and x=2.0
-    "warmup_epochs",
-    "box",
-    "cls",
-    "dfl",
-    "degrees",
-    "shear",
-    "time",
-    "workspace",
-    "batch",
-}
-CFG_FRACTION_KEYS = {  # fractional float arguments with 0.0<=values<=1.0
-    "dropout",
-    "lr0",
-    "lrf",
-    "momentum",
-    "weight_decay",
-    "warmup_momentum",
-    "warmup_bias_lr",
-    "hsv_h",
-    "hsv_s",
-    "hsv_v",
-    "translate",
-    "scale",
-    "perspective",
-    "flipud",
-    "fliplr",
-    "bgr",
-    "mosaic",
-    "mixup",
-    "copy_paste",
-    "conf",
-    "iou",
-    "fraction",
-}
-CFG_INT_KEYS = {  # integer-only arguments
-    "epochs",
-    "patience",
-    "workers",
-    "seed",
-    "close_mosaic",
-    "mask_ratio",
-    "max_det",
-    "vid_stride",
-    "line_width",
-    "nbs",
-    "save_period",
-}
-CFG_BOOL_KEYS = {  # boolean-only arguments
-    "save",
-    "exist_ok",
-    "verbose",
-    "deterministic",
-    "single_cls",
-    "rect",
-    "cos_lr",
-    "overlap_mask",
-    "val",
-    "save_json",
-    "save_hybrid",
-    "half",
-    "dnn",
-    "plots",
-    "show",
-    "save_txt",
-    "save_conf",
-    "save_crop",
-    "save_frames",
-    "show_labels",
-    "show_conf",
-    "visualize",
-    "augment",
-    "agnostic_nms",
-    "retina_masks",
-    "show_boxes",
-    "keras",
-    "optimize",
-    "int8",
-    "dynamic",
-    "simplify",
-    "nms",
-    "profile",
-    "multi_scale",
-}
+CFG_FLOAT_KEYS = frozenset(
+    {  # integer or float arguments, i.e. x=2 and x=2.0
+        "warmup_epochs",
+        "box",
+        "cls",
+        "dfl",
+        "degrees",
+        "shear",
+        "time",
+        "workspace",
+        "batch",
+    }
+)
+CFG_FRACTION_KEYS = frozenset(
+    {  # fractional float arguments with 0.0<=values<=1.0
+        "dropout",
+        "lr0",
+        "lrf",
+        "momentum",
+        "weight_decay",
+        "warmup_momentum",
+        "warmup_bias_lr",
+        "hsv_h",
+        "hsv_s",
+        "hsv_v",
+        "translate",
+        "scale",
+        "perspective",
+        "flipud",
+        "fliplr",
+        "bgr",
+        "mosaic",
+        "mixup",
+        "copy_paste",
+        "conf",
+        "iou",
+        "fraction",
+    }
+)
+CFG_INT_KEYS = frozenset(
+    {  # integer-only arguments
+        "epochs",
+        "patience",
+        "workers",
+        "seed",
+        "close_mosaic",
+        "mask_ratio",
+        "max_det",
+        "vid_stride",
+        "line_width",
+        "nbs",
+        "save_period",
+    }
+)
+CFG_BOOL_KEYS = frozenset(
+    {  # boolean-only arguments
+        "save",
+        "exist_ok",
+        "verbose",
+        "deterministic",
+        "single_cls",
+        "rect",
+        "cos_lr",
+        "overlap_mask",
+        "val",
+        "save_json",
+        "save_hybrid",
+        "half",
+        "dnn",
+        "plots",
+        "show",
+        "save_txt",
+        "save_conf",
+        "save_crop",
+        "save_frames",
+        "show_labels",
+        "show_conf",
+        "visualize",
+        "augment",
+        "agnostic_nms",
+        "retina_masks",
+        "show_boxes",
+        "keras",
+        "optimize",
+        "int8",
+        "dynamic",
+        "simplify",
+        "nms",
+        "profile",
+        "multi_scale",
+    }
+)
 
 
 def cfg2dict(cfg):
@@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
         - Prints detailed error messages for each mismatched key to help users correct their configurations.
     """
     custom = _handle_deprecation(custom)
-    base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
+    base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
     if mismatched := [k for k in custom_keys if k not in base_keys]:
         from difflib import get_close_matches
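A hedged sketch of how frozenset key groups like CFG_INT_KEYS and CFG_FRACTION_KEYS above are typically consumed when validating overrides (validate_arg is a hypothetical helper for illustration, not the library API):

# Hedged sketch with abridged stand-in key groups
CFG_INT_KEYS = frozenset({"epochs", "workers", "seed"})
CFG_FRACTION_KEYS = frozenset({"dropout", "conf", "iou"})

def validate_arg(key, value):
    """Validate a single override against its expected type group."""
    if key in CFG_INT_KEYS and not isinstance(value, int):
        raise TypeError(f"'{key}={value}' must be an int")
    if key in CFG_FRACTION_KEYS and not 0.0 <= float(value) <= 1.0:
        raise ValueError(f"'{key}={value}' must be a fraction in [0.0, 1.0]")
    return value

validate_arg("epochs", 100)  # ok
validate_arg("conf", 0.25)   # ok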

ultralytics/nn/tasks.py

@@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
     LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
     ch = [ch]
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
-    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
-        m = (
-            getattr(torch.nn, m[3:])
-            if "nn." in m
-            else getattr(__import__("torchvision").ops, m[16:])
-            if "torchvision.ops." in m
-            else globals()[m]
-        )  # get module
-        for j, a in enumerate(args):
-            if isinstance(a, str):
-                with contextlib.suppress(ValueError):
-                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
-        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in {
+    base_modules = frozenset(
+        {
             Classify,
             Conv,
             ConvTranspose,
@@ -1001,33 +989,49 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             PSA,
             SCDown,
             C2fCIB,
-        }:
+        }
+    )
+    repeat_modules = frozenset(  # modules with 'repeat' arguments
+        {
+            BottleneckCSP,
+            C1,
+            C2,
+            C2f,
+            C3k2,
+            C2fAttn,
+            C3,
+            C3TR,
+            C3Ghost,
+            C3x,
+            RepC3,
+            C2fPSA,
+            C2fCIB,
+            C2PSA,
+        }
+    )
+    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
+        m = (
+            getattr(torch.nn, m[3:])
+            if "nn." in m
+            else getattr(__import__("torchvision").ops, m[16:])
+            if "torchvision.ops." in m
+            else globals()[m]
+        )  # get module
+        for j, a in enumerate(args):
+            if isinstance(a, str):
+                with contextlib.suppress(ValueError):
+                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
+        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
+        if m in base_modules:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
-            if m is C2fAttn:
-                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
-                args[2] = int(
-                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
-                )  # num heads
+            if m is C2fAttn:  # set 1) embed channels and 2) num heads
+                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
+                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
             args = [c1, c2, *args[1:]]
-            if m in {
-                BottleneckCSP,
-                C1,
-                C2,
-                C2f,
-                C3k2,
-                C2fAttn,
-                C3,
-                C3TR,
-                C3Ghost,
-                C3x,
-                RepC3,
-                C2fPSA,
-                C2fCIB,
-                C2PSA,
-            }:
+            if m in repeat_modules:
                 args.insert(2, n)  # number of repeats
                 n = 1
             if m is C3k2:  # for M/L/X sizes
@@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                     args[3] = True
         elif m is AIFI:
             args = [ch[f], *args]
-        elif m in {HGStem, HGBlock}:
+        elif m in frozenset({HGStem, HGBlock}):
             c1, cm, c2 = ch[f], args[0], args[1]
             args = [c1, cm, c2, *args[2:]]
             if m is HGBlock:
@@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
+        elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 m.legacy = legacy
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
-        elif m in {CBLinear, TorchVision, Index}:
+        elif m in frozenset({CBLinear, TorchVision, Index}):
             c2 = args[0]
             c1 = ch[f]
             args = [c1, c2, *args[1:]]
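Beyond swapping set for frozenset, the parse_model change above also hoists the two module sets out of the per-layer loop: the old code rebuilt a set literal on every iteration, while base_modules and repeat_modules are now built once per call. A generic sketch of that pattern, using stand-in classes rather than the real modules:

# Hedged sketch of the hoisting pattern (stand-in classes, not the real YOLO modules)
class Conv: ...
class Classify: ...
class Concat: ...

base_modules = frozenset({Conv, Classify})  # built once, before the loop

for m in (Conv, Concat, Classify):
    if m in base_modules:  # O(1) lookup against the hoisted constant
        print(f"{m.__name__}: takes (c1, c2) channel arguments")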

ultralytics/trackers/utils/matching.py

@@ -55,8 +55,8 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
             unmatched_a = list(np.arange(cost_matrix.shape[0]))
             unmatched_b = list(np.arange(cost_matrix.shape[1]))
         else:
-            unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
-            unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
+            unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
+            unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))

     return matches, unmatched_a, unmatched_b
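In linear_assignment, frozenset difference behaves exactly like set difference; the result is materialized back into a list of unmatched row/column indices. A small hedged sketch with made-up match data:

import numpy as np

matches = np.array([[0, 2], [3, 1]])  # matched (row, col) index pairs
unmatched_a = list(frozenset(np.arange(4)) - frozenset(matches[:, 0]))  # rows with no match
unmatched_b = list(frozenset(np.arange(3)) - frozenset(matches[:, 1]))  # cols with no match
print(unmatched_a, unmatched_b)  # rows {1, 2} and col {0}; iteration order not guaranteed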

ultralytics/utils/__init__.py

@@ -1227,7 +1227,7 @@ class SettingsManager(JSONDict):
     def _validate_settings(self):
         """Validate the current settings and reset if necessary."""
-        correct_keys = set(self.keys()) == set(self.defaults.keys())
+        correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
         correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
         correct_version = self.get("settings_version", "") == self.version
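For _validate_settings, set/frozenset equality compares membership only, so the check passes regardless of key order and fails if any key is missing or extra. A hedged sketch with made-up settings keys:

defaults = {"settings_version": "x.y.z", "datasets_dir": "", "sync": True}
current = {"sync": True, "datasets_dir": "", "settings_version": "x.y.z"}
print(frozenset(current.keys()) == frozenset(defaults.keys()))  # True, key order is irrelevant
current.pop("sync")
print(frozenset(current.keys()) == frozenset(defaults.keys()))  # False, settings would be reset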

ultralytics/utils/instance.py

@@ -407,7 +407,7 @@ class Instances:
         cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
         seg_len = [b.segments.shape[1] for b in instances_list]
-        if len(set(seg_len)) > 1:  # resample segments if there's different length
+        if len(frozenset(seg_len)) > 1:  # resample segments if there's different length
             max_len = max(seg_len)
             cat_segments = np.concatenate(
                 [
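In the Instances concatenation code above, len(frozenset(seg_len)) > 1 is a cheap "are the segment lengths not all equal?" test that decides whether segments must be resampled before np.concatenate. A small sketch:

seg_len = [100, 100, 250]
print(len(frozenset(seg_len)) > 1)  # True  -> resample all segments to max(seg_len) first
seg_len = [100, 100, 100]
print(len(frozenset(seg_len)) > 1)  # False -> lengths already match, concatenate directly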