Use frozenset() (#18785)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
fb3e5adfd7
commit
9341c1df76
5 changed files with 143 additions and 131 deletions
|
|
@ -47,8 +47,8 @@ SOLUTION_MAP = {
|
||||||
}
|
}
|
||||||
|
|
||||||
# Define valid tasks and modes
|
# Define valid tasks and modes
|
||||||
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
|
MODES = frozenset({"train", "val", "predict", "export", "track", "benchmark"})
|
||||||
TASKS = {"detect", "segment", "classify", "pose", "obb"}
|
TASKS = frozenset({"detect", "segment", "classify", "pose", "obb"})
|
||||||
TASK2DATA = {
|
TASK2DATA = {
|
||||||
"detect": "coco8.yaml",
|
"detect": "coco8.yaml",
|
||||||
"segment": "coco8-seg.yaml",
|
"segment": "coco8-seg.yaml",
|
||||||
|
|
@ -70,7 +70,7 @@ TASK2METRIC = {
|
||||||
"pose": "metrics/mAP50-95(P)",
|
"pose": "metrics/mAP50-95(P)",
|
||||||
"obb": "metrics/mAP50-95(B)",
|
"obb": "metrics/mAP50-95(B)",
|
||||||
}
|
}
|
||||||
MODELS = {TASK2MODEL[task] for task in TASKS}
|
MODELS = frozenset({TASK2MODEL[task] for task in TASKS})
|
||||||
|
|
||||||
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
|
ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
|
||||||
SOLUTIONS_HELP_MSG = f"""
|
SOLUTIONS_HELP_MSG = f"""
|
||||||
|
|
@ -144,90 +144,98 @@ CLI_HELP_MSG = f"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Define keys for arg type checks
|
# Define keys for arg type checks
|
||||||
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
|
CFG_FLOAT_KEYS = frozenset(
|
||||||
"warmup_epochs",
|
{ # integer or float arguments, i.e. x=2 and x=2.0
|
||||||
"box",
|
"warmup_epochs",
|
||||||
"cls",
|
"box",
|
||||||
"dfl",
|
"cls",
|
||||||
"degrees",
|
"dfl",
|
||||||
"shear",
|
"degrees",
|
||||||
"time",
|
"shear",
|
||||||
"workspace",
|
"time",
|
||||||
"batch",
|
"workspace",
|
||||||
}
|
"batch",
|
||||||
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
|
}
|
||||||
"dropout",
|
)
|
||||||
"lr0",
|
CFG_FRACTION_KEYS = frozenset(
|
||||||
"lrf",
|
{ # fractional float arguments with 0.0<=values<=1.0
|
||||||
"momentum",
|
"dropout",
|
||||||
"weight_decay",
|
"lr0",
|
||||||
"warmup_momentum",
|
"lrf",
|
||||||
"warmup_bias_lr",
|
"momentum",
|
||||||
"hsv_h",
|
"weight_decay",
|
||||||
"hsv_s",
|
"warmup_momentum",
|
||||||
"hsv_v",
|
"warmup_bias_lr",
|
||||||
"translate",
|
"hsv_h",
|
||||||
"scale",
|
"hsv_s",
|
||||||
"perspective",
|
"hsv_v",
|
||||||
"flipud",
|
"translate",
|
||||||
"fliplr",
|
"scale",
|
||||||
"bgr",
|
"perspective",
|
||||||
"mosaic",
|
"flipud",
|
||||||
"mixup",
|
"fliplr",
|
||||||
"copy_paste",
|
"bgr",
|
||||||
"conf",
|
"mosaic",
|
||||||
"iou",
|
"mixup",
|
||||||
"fraction",
|
"copy_paste",
|
||||||
}
|
"conf",
|
||||||
CFG_INT_KEYS = { # integer-only arguments
|
"iou",
|
||||||
"epochs",
|
"fraction",
|
||||||
"patience",
|
}
|
||||||
"workers",
|
)
|
||||||
"seed",
|
CFG_INT_KEYS = frozenset(
|
||||||
"close_mosaic",
|
{ # integer-only arguments
|
||||||
"mask_ratio",
|
"epochs",
|
||||||
"max_det",
|
"patience",
|
||||||
"vid_stride",
|
"workers",
|
||||||
"line_width",
|
"seed",
|
||||||
"nbs",
|
"close_mosaic",
|
||||||
"save_period",
|
"mask_ratio",
|
||||||
}
|
"max_det",
|
||||||
CFG_BOOL_KEYS = { # boolean-only arguments
|
"vid_stride",
|
||||||
"save",
|
"line_width",
|
||||||
"exist_ok",
|
"nbs",
|
||||||
"verbose",
|
"save_period",
|
||||||
"deterministic",
|
}
|
||||||
"single_cls",
|
)
|
||||||
"rect",
|
CFG_BOOL_KEYS = frozenset(
|
||||||
"cos_lr",
|
{ # boolean-only arguments
|
||||||
"overlap_mask",
|
"save",
|
||||||
"val",
|
"exist_ok",
|
||||||
"save_json",
|
"verbose",
|
||||||
"save_hybrid",
|
"deterministic",
|
||||||
"half",
|
"single_cls",
|
||||||
"dnn",
|
"rect",
|
||||||
"plots",
|
"cos_lr",
|
||||||
"show",
|
"overlap_mask",
|
||||||
"save_txt",
|
"val",
|
||||||
"save_conf",
|
"save_json",
|
||||||
"save_crop",
|
"save_hybrid",
|
||||||
"save_frames",
|
"half",
|
||||||
"show_labels",
|
"dnn",
|
||||||
"show_conf",
|
"plots",
|
||||||
"visualize",
|
"show",
|
||||||
"augment",
|
"save_txt",
|
||||||
"agnostic_nms",
|
"save_conf",
|
||||||
"retina_masks",
|
"save_crop",
|
||||||
"show_boxes",
|
"save_frames",
|
||||||
"keras",
|
"show_labels",
|
||||||
"optimize",
|
"show_conf",
|
||||||
"int8",
|
"visualize",
|
||||||
"dynamic",
|
"augment",
|
||||||
"simplify",
|
"agnostic_nms",
|
||||||
"nms",
|
"retina_masks",
|
||||||
"profile",
|
"show_boxes",
|
||||||
"multi_scale",
|
"keras",
|
||||||
}
|
"optimize",
|
||||||
|
"int8",
|
||||||
|
"dynamic",
|
||||||
|
"simplify",
|
||||||
|
"nms",
|
||||||
|
"profile",
|
||||||
|
"multi_scale",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def cfg2dict(cfg):
|
def cfg2dict(cfg):
|
||||||
|
|
@ -472,7 +480,7 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
||||||
- Prints detailed error messages for each mismatched key to help users correct their configurations.
|
- Prints detailed error messages for each mismatched key to help users correct their configurations.
|
||||||
"""
|
"""
|
||||||
custom = _handle_deprecation(custom)
|
custom = _handle_deprecation(custom)
|
||||||
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
base_keys, custom_keys = (frozenset(x.keys()) for x in (base, custom))
|
||||||
if mismatched := [k for k in custom_keys if k not in base_keys]:
|
if mismatched := [k for k in custom_keys if k not in base_keys]:
|
||||||
from difflib import get_close_matches
|
from difflib import get_close_matches
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -954,20 +954,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
||||||
ch = [ch]
|
ch = [ch]
|
||||||
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
||||||
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
base_modules = frozenset(
|
||||||
m = (
|
{
|
||||||
getattr(torch.nn, m[3:])
|
|
||||||
if "nn." in m
|
|
||||||
else getattr(__import__("torchvision").ops, m[16:])
|
|
||||||
if "torchvision.ops." in m
|
|
||||||
else globals()[m]
|
|
||||||
) # get module
|
|
||||||
for j, a in enumerate(args):
|
|
||||||
if isinstance(a, str):
|
|
||||||
with contextlib.suppress(ValueError):
|
|
||||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
|
||||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
|
||||||
if m in {
|
|
||||||
Classify,
|
Classify,
|
||||||
Conv,
|
Conv,
|
||||||
ConvTranspose,
|
ConvTranspose,
|
||||||
|
|
@ -1001,33 +989,49 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
PSA,
|
PSA,
|
||||||
SCDown,
|
SCDown,
|
||||||
C2fCIB,
|
C2fCIB,
|
||||||
}:
|
}
|
||||||
|
)
|
||||||
|
repeat_modules = frozenset( # modules with 'repeat' arguments
|
||||||
|
{
|
||||||
|
BottleneckCSP,
|
||||||
|
C1,
|
||||||
|
C2,
|
||||||
|
C2f,
|
||||||
|
C3k2,
|
||||||
|
C2fAttn,
|
||||||
|
C3,
|
||||||
|
C3TR,
|
||||||
|
C3Ghost,
|
||||||
|
C3x,
|
||||||
|
RepC3,
|
||||||
|
C2fPSA,
|
||||||
|
C2fCIB,
|
||||||
|
C2PSA,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
||||||
|
m = (
|
||||||
|
getattr(torch.nn, m[3:])
|
||||||
|
if "nn." in m
|
||||||
|
else getattr(__import__("torchvision").ops, m[16:])
|
||||||
|
if "torchvision.ops." in m
|
||||||
|
else globals()[m]
|
||||||
|
) # get module
|
||||||
|
for j, a in enumerate(args):
|
||||||
|
if isinstance(a, str):
|
||||||
|
with contextlib.suppress(ValueError):
|
||||||
|
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||||
|
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||||
|
if m in base_modules:
|
||||||
c1, c2 = ch[f], args[0]
|
c1, c2 = ch[f], args[0]
|
||||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||||
if m is C2fAttn:
|
if m is C2fAttn: # set 1) embed channels and 2) num heads
|
||||||
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels
|
args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8)
|
||||||
args[2] = int(
|
args[2] = int(max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2])
|
||||||
max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2]
|
|
||||||
) # num heads
|
|
||||||
|
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
if m in {
|
if m in repeat_modules:
|
||||||
BottleneckCSP,
|
|
||||||
C1,
|
|
||||||
C2,
|
|
||||||
C2f,
|
|
||||||
C3k2,
|
|
||||||
C2fAttn,
|
|
||||||
C3,
|
|
||||||
C3TR,
|
|
||||||
C3Ghost,
|
|
||||||
C3x,
|
|
||||||
RepC3,
|
|
||||||
C2fPSA,
|
|
||||||
C2fCIB,
|
|
||||||
C2PSA,
|
|
||||||
}:
|
|
||||||
args.insert(2, n) # number of repeats
|
args.insert(2, n) # number of repeats
|
||||||
n = 1
|
n = 1
|
||||||
if m is C3k2: # for M/L/X sizes
|
if m is C3k2: # for M/L/X sizes
|
||||||
|
|
@ -1036,7 +1040,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
args[3] = True
|
args[3] = True
|
||||||
elif m is AIFI:
|
elif m is AIFI:
|
||||||
args = [ch[f], *args]
|
args = [ch[f], *args]
|
||||||
elif m in {HGStem, HGBlock}:
|
elif m in frozenset({HGStem, HGBlock}):
|
||||||
c1, cm, c2 = ch[f], args[0], args[1]
|
c1, cm, c2 = ch[f], args[0], args[1]
|
||||||
args = [c1, cm, c2, *args[2:]]
|
args = [c1, cm, c2, *args[2:]]
|
||||||
if m is HGBlock:
|
if m is HGBlock:
|
||||||
|
|
@ -1048,7 +1052,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
args = [ch[f]]
|
args = [ch[f]]
|
||||||
elif m is Concat:
|
elif m is Concat:
|
||||||
c2 = sum(ch[x] for x in f)
|
c2 = sum(ch[x] for x in f)
|
||||||
elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}:
|
elif m in frozenset({Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}):
|
||||||
args.append([ch[x] for x in f])
|
args.append([ch[x] for x in f])
|
||||||
if m is Segment:
|
if m is Segment:
|
||||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||||
|
|
@ -1056,7 +1060,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
m.legacy = legacy
|
m.legacy = legacy
|
||||||
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1
|
||||||
args.insert(1, [ch[x] for x in f])
|
args.insert(1, [ch[x] for x in f])
|
||||||
elif m in {CBLinear, TorchVision, Index}:
|
elif m in frozenset({CBLinear, TorchVision, Index}):
|
||||||
c2 = args[0]
|
c2 = args[0]
|
||||||
c1 = ch[f]
|
c1 = ch[f]
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
|
|
|
||||||
|
|
@ -55,8 +55,8 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
|
||||||
unmatched_a = list(np.arange(cost_matrix.shape[0]))
|
unmatched_a = list(np.arange(cost_matrix.shape[0]))
|
||||||
unmatched_b = list(np.arange(cost_matrix.shape[1]))
|
unmatched_b = list(np.arange(cost_matrix.shape[1]))
|
||||||
else:
|
else:
|
||||||
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
|
unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
|
||||||
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))
|
unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
|
||||||
|
|
||||||
return matches, unmatched_a, unmatched_b
|
return matches, unmatched_a, unmatched_b
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1227,7 +1227,7 @@ class SettingsManager(JSONDict):
|
||||||
|
|
||||||
def _validate_settings(self):
|
def _validate_settings(self):
|
||||||
"""Validate the current settings and reset if necessary."""
|
"""Validate the current settings and reset if necessary."""
|
||||||
correct_keys = set(self.keys()) == set(self.defaults.keys())
|
correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
|
||||||
correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
|
correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
|
||||||
correct_version = self.get("settings_version", "") == self.version
|
correct_version = self.get("settings_version", "") == self.version
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -407,7 +407,7 @@ class Instances:
|
||||||
|
|
||||||
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
|
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
|
||||||
seg_len = [b.segments.shape[1] for b in instances_list]
|
seg_len = [b.segments.shape[1] for b in instances_list]
|
||||||
if len(set(seg_len)) > 1: # resample segments if there's different length
|
if len(frozenset(seg_len)) > 1: # resample segments if there's different length
|
||||||
max_len = max(seg_len)
|
max_len = max(seg_len)
|
||||||
cat_segments = np.concatenate(
|
cat_segments = np.concatenate(
|
||||||
[
|
[
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue