Use frozenset() (#18785)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-21 01:38:00 +01:00 committed by GitHub
parent fb3e5adfd7
commit 9341c1df76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 143 additions and 131 deletions

View file

@ -47,8 +47,8 @@ SOLUTION_MAP = {
}
# Define valid tasks and modes
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
TASKS = {"detect", "segment", "classify", "pose", "obb"}
MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
TASK2DATA = {
"detect": "coco8.yaml",
"segment": "coco8-seg.yaml",
@ -70,7 +70,7 @@ TASK2METRIC = {
"pose": "metrics/mAP50-95(P)",
"obb": "metrics/mAP50-95(B)",
}
MODELS = {TASK2MODEL[task] for task in TASKS}
MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
SOLUTIONS_HELP_MSG = f"""
@ -144,90 +144,98 @@ CLI_HELP_MSG = f"""
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
"warmup_epochs",
"box",
"cls",
"dfl",
"degrees",
"shear",
"time",
"workspace",
"batch",
}
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
"dropout",
"lr0",
"lrf",
"momentum",
"weight_decay",
"warmup_momentum",
"warmup_bias_lr",
"hsv_h",
"hsv_s",
"hsv_v",
"translate",
"scale",
"perspective",
"flipud",
"fliplr",
"bgr",
"mosaic",
"mixup",
"copy_paste",
"conf",
"iou",
"fraction",
}
CFG_INT_KEYS = { # integer-only arguments
"epochs",
"patience",
"workers",
"seed",
"close_mosaic",
"mask_ratio",
"max_det",
"vid_stride",
"line_width",
"nbs",
"save_period",
}
CFG_BOOL_KEYS = { # boolean-only arguments
"save",
"exist_ok",
"verbose",
"deterministic",
"single_cls",
"rect",
"cos_lr",
"overlap_mask",
"val",
"save_json",
"save_hybrid",
"half",
"dnn",
"plots",
"show",
"save_txt",
"save_conf",
"save_crop",
"save_frames",
"show_labels",
"show_conf",
"visualize",
"augment",
"agnostic_nms",
"retina_masks",
"show_boxes",
"keras",
"optimize",
"int8",
"dynamic",
"simplify",
"nms",
"profile",
"multi_scale",
}
CFG_FLOAT_KEYS = frozenset(
{ # integer or float arguments, i.e. x=2 and x=2.0
"warmup_epochs",
"box",
"cls",
"dfl",
"degrees",
"shear",
"time",
"workspace",
"batch",
}
)
CFG_FRACTION_KEYS = frozenset(
{ # fractional float arguments with 0.0<=values<=1.0
"dropout",
"lr0",
"lrf",
"momentum",
"weight_decay",
"warmup_momentum",
"warmup_bias_lr",
"hsv_h",
"hsv_s",
"hsv_v",
"translate",
"scale",
"perspective",
"flipud",
"fliplr",
"bgr",
"mosaic",
"mixup",
"copy_paste",
"conf",
"iou",
"fraction",
}
)
CFG_INT_KEYS = frozenset(
{ # integer-only arguments
"epochs",
"patience",
"workers",
"seed",
"close_mosaic",
"mask_ratio",
"max_det",
"vid_stride",
"line_width",
"nbs",
"save_period",
}
)
CFG_BOOL_KEYS = frozenset(
{ # boolean-only arguments
"save",
"exist_ok",
"verbose",
"deterministic",
"single_cls",
"rect",
"cos_lr",
"overlap_mask",
"val",
"save_json",
"save_hybrid",
"half",
"dnn",
"plots",
"show",
"save_txt",
"save_conf",
"save_crop",
"save_frames",
"show_labels",
"show_conf",
"visualize",
"augment",
"agnostic_nms",
"retina_masks",
"show_boxes",
"keras",
"optimize",
"int8",
"dynamic",
"simplify",
"nms",
"profile",
"multi_scale",
}
)
def cfg2dict(cfg):
@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
- Prints detailed error messages for each mismatched key to help users correct their configurations.
"""
custom = _handle_deprecation(custom)
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
if mismatched := [k for k in custom_keys if k not in base_keys]:
from difflib import get_close_matches