From 9341c1df7614d4b502b0aeaa15c336c0f1f68c6e Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 21 Jan 2025 01:38:00 +0100
Subject: [PATCH] Use frozenset() (#18785)

Signed-off-by: Glenn Jocher
Co-authored-by: UltralyticsAssistant
---
 ultralytics/cfg/__init__.py            | 184 +++++++++++++------------
 ultralytics/nn/tasks.py                |  82 +++++------
 ultralytics/trackers/utils/matching.py |   4 +-
 ultralytics/utils/__init__.py          |   2 +-
 ultralytics/utils/instance.py          |   2 +-
 5 files changed, 143 insertions(+), 131 deletions(-)

diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 8625f7c9..a98de949 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -47,8 +47,8 @@ SOLUTION_MAP = {
 }
 
 # Define valid tasks and modes
-MODES = {"train", "val", "predict", "export", "track", "benchmark"}
-TASKS = {"detect", "segment", "classify", "pose", "obb"}
+MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
+TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
 TASK2DATA = {
     "detect": "coco8.yaml",
     "segment": "coco8-seg.yaml",
@@ -70,7 +70,7 @@ TASK2METRIC = {
     "pose": "metrics/mAP50-95(P)",
     "obb": "metrics/mAP50-95(B)",
 }
-MODELS = {TASK2MODEL[task] for task in TASKS}
+MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
 
 ARGV = sys.argv or ["", ""]  # sometimes sys.argv = []
 SOLUTIONS_HELP_MSG = f"""
@@ -144,90 +144,98 @@ CLI_HELP_MSG = f"""
     """
 
 # Define keys for arg type checks
-CFG_FLOAT_KEYS = {  # integer or float arguments, i.e. x=2 and x=2.0
-    "warmup_epochs",
-    "box",
-    "cls",
-    "dfl",
-    "degrees",
-    "shear",
-    "time",
-    "workspace",
-    "batch",
-}
-CFG_FRACTION_KEYS = {  # fractional float arguments with 0.0<=values<=1.0
-    "dropout",
-    "lr0",
-    "lrf",
-    "momentum",
-    "weight_decay",
-    "warmup_momentum",
-    "warmup_bias_lr",
-    "hsv_h",
-    "hsv_s",
-    "hsv_v",
-    "translate",
-    "scale",
-    "perspective",
-    "flipud",
-    "fliplr",
-    "bgr",
-    "mosaic",
-    "mixup",
-    "copy_paste",
-    "conf",
-    "iou",
-    "fraction",
-}
-CFG_INT_KEYS = {  # integer-only arguments
-    "epochs",
-    "patience",
-    "workers",
-    "seed",
-    "close_mosaic",
-    "mask_ratio",
-    "max_det",
-    "vid_stride",
-    "line_width",
-    "nbs",
-    "save_period",
-}
-CFG_BOOL_KEYS = {  # boolean-only arguments
-    "save",
-    "exist_ok",
-    "verbose",
-    "deterministic",
-    "single_cls",
-    "rect",
-    "cos_lr",
-    "overlap_mask",
-    "val",
-    "save_json",
-    "save_hybrid",
-    "half",
-    "dnn",
-    "plots",
-    "show",
-    "save_txt",
-    "save_conf",
-    "save_crop",
-    "save_frames",
-    "show_labels",
-    "show_conf",
-    "visualize",
-    "augment",
-    "agnostic_nms",
-    "retina_masks",
-    "show_boxes",
-    "keras",
-    "optimize",
-    "int8",
-    "dynamic",
-    "simplify",
-    "nms",
-    "profile",
-    "multi_scale",
-}
+CFG_FLOAT_KEYS = frozenset(
+    {  # integer or float arguments, i.e. x=2 and x=2.0
+        "warmup_epochs",
+        "box",
+        "cls",
+        "dfl",
+        "degrees",
+        "shear",
+        "time",
+        "workspace",
+        "batch",
+    }
+)
+CFG_FRACTION_KEYS = frozenset(
+    {  # fractional float arguments with 0.0<=values<=1.0
+        "dropout",
+        "lr0",
+        "lrf",
+        "momentum",
+        "weight_decay",
+        "warmup_momentum",
+        "warmup_bias_lr",
+        "hsv_h",
+        "hsv_s",
+        "hsv_v",
+        "translate",
+        "scale",
+        "perspective",
+        "flipud",
+        "fliplr",
+        "bgr",
+        "mosaic",
+        "mixup",
+        "copy_paste",
+        "conf",
+        "iou",
+        "fraction",
+    }
+)
+CFG_INT_KEYS = frozenset(
+    {  # integer-only arguments
+        "epochs",
+        "patience",
+        "workers",
+        "seed",
+        "close_mosaic",
+        "mask_ratio",
+        "max_det",
+        "vid_stride",
+        "line_width",
+        "nbs",
+        "save_period",
+    }
+)
+CFG_BOOL_KEYS = frozenset(
+    {  # boolean-only arguments
+        "save",
+        "exist_ok",
+        "verbose",
+        "deterministic",
+        "single_cls",
+        "rect",
+        "cos_lr",
+        "overlap_mask",
+        "val",
+        "save_json",
+        "save_hybrid",
+        "half",
+        "dnn",
+        "plots",
+        "show",
+        "save_txt",
+        "save_conf",
+        "save_crop",
+        "save_frames",
+        "show_labels",
+        "show_conf",
+        "visualize",
+        "augment",
+        "agnostic_nms",
+        "retina_masks",
+        "show_boxes",
+        "keras",
+        "optimize",
+        "int8",
+        "dynamic",
+        "simplify",
+        "nms",
+        "profile",
+        "multi_scale",
+    }
+)
 
 
 def cfg2dict(cfg):
@@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
         - Prints detailed error messages for each mismatched key to help users correct their configurations.
     """
     custom = _handle_deprecation(custom)
-    base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
+    base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
     if mismatched := [k for k in custom_keys if k not in base_keys]:
         from difflib import get_close_matches
 
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index f5dec3c5..cbabe49a 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
         LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
     ch = [ch]
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
-    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
-        m = (
-            getattr(torch.nn, m[3:])
-            if "nn." in m
-            else getattr(__import__("torchvision").ops, m[16:])
-            if "torchvision.ops." in m
-            else globals()[m]
-        )  # get module
-        for j, a in enumerate(args):
-            if isinstance(a, str):
-                with contextlib.suppress(ValueError):
-                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
-        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in {
+    base_modules = frozenset(
+        {
             Classify,
             Conv,
             ConvTranspose,
@@ -1001,33 +989,49 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             PSA,
             SCDown,
             C2fCIB,
-        }:
+        }
+    )
+    repeat_modules = frozenset(  # modules with 'repeat' arguments
+        {
+            BottleneckCSP,
+            C1,
+            C2,
+            C2f,
+            C3k2,
+            C2fAttn,
+            C3,
+            C3TR,
+            C3Ghost,
+            C3x,
+            RepC3,
+            C2fPSA,
+            C2fCIB,
+            C2PSA,
+        }
+    )
+    for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]):  # from, number, module, args
+        m = (
+            getattr(torch.nn, m[3:])
+            if "nn." in m
+            else getattr(__import__("torchvision").ops, m[16:])
+            if "torchvision.ops." in m
+            else globals()[m]
+        )  # get module
+        for j, a in enumerate(args):
+            if isinstance(a, str):
+                with contextlib.suppress(ValueError):
+                    args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
+        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
+        if m in base_modules:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
-            if m is C2fAttn:
-                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)  # embed channels
-                args[2] = int(
-                    max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
-                )  # num heads
+            if m is C2fAttn:  # set 1) embed channels and 2) num heads
+                args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
+                args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
 
             args = [c1, c2, *args[1:]]
-            if m in {
-                BottleneckCSP,
-                C1,
-                C2,
-                C2f,
-                C3k2,
-                C2fAttn,
-                C3,
-                C3TR,
-                C3Ghost,
-                C3x,
-                RepC3,
-                C2fPSA,
-                C2fCIB,
-                C2PSA,
-            }:
+            if m in repeat_modules:
                 args.insert(2, n)  # number of repeats
                 n = 1
             if m is C3k2:  # for M/L/X sizes
@@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 args[3] = True
         elif m is AIFI:
             args = [ch[f], *args]
-        elif m in {HGStem, HGBlock}:
+        elif m in frozenset({HGStem, HGBlock}):
             c1, cm, c2 = ch[f], args[0], args[1]
             args = [c1, cm, c2, *args[2:]]
             if m is HGBlock:
@@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
+        elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             m.legacy = legacy
         elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
             args.insert(1, [ch[x] for x in f])
-        elif m in {CBLinear, TorchVision, Index}:
+        elif m in frozenset({CBLinear, TorchVision, Index}):
             c2 = args[0]
             c1 = ch[f]
             args = [c1, c2, *args[1:]]
diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py
index e4b354f1..f15f64df 100644
--- a/ultralytics/trackers/utils/matching.py
+++ b/ultralytics/trackers/utils/matching.py
@@ -55,8 +55,8 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
         unmatched_a = list(np.arange(cost_matrix.shape[0]))
         unmatched_b = list(np.arange(cost_matrix.shape[1]))
     else:
-        unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
-        unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
+        unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
+        unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
 
     return matches, unmatched_a, unmatched_b
 
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
index fa303923..753a02f7 100644
--- a/ultralytics/utils/__init__.py
+++ b/ultralytics/utils/__init__.py
@@ -1227,7 +1227,7 @@ class SettingsManager(JSONDict):
 
     def _validate_settings(self):
         """Validate the current settings and reset if necessary."""
-        correct_keys = set(self.keys()) == set(self.defaults.keys())
+        correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
         correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
         correct_version = self.get("settings_version", "") == self.version
 
diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py
index e92a9614..71ce3626 100644
--- a/ultralytics/utils/instance.py
+++ b/ultralytics/utils/instance.py
@@ -407,7 +407,7 @@ class Instances:
 
         cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
         seg_len = [b.segments.shape[1] for b in instances_list]
-        if len(set(seg_len)) > 1:  # resample segments if there's different length
+        if len(frozenset(seg_len)) > 1:  # resample segments if there's different length
             max_len = max(seg_len)
             cat_segments = np.concatenate(
                 [
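
Note (not part of the patch): a minimal sketch of the property this change relies on. The `MODES` value mirrors the constant defined in ultralytics/cfg/__init__.py above; the `check_mode` helper is hypothetical and only illustrates that a frozenset supports the same O(1) membership tests as a set while being immutable and hashable.

# Illustrative sketch only -- not taken from the Ultralytics codebase.
MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})  # same value as in the patch


def check_mode(mode: str) -> None:
    """Hypothetical helper: reject anything outside the frozen constant."""
    if mode not in MODES:  # membership test behaves exactly like a plain set
        raise ValueError(f"Invalid mode '{mode}', valid modes are {sorted(MODES)}")


check_mode("train")  # passes
# MODES.add("deploy")  # AttributeError: 'frozenset' object has no attribute 'add'
# {MODES: "valid modes"}  # unlike set, a frozenset is hashable, so it can key a dict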