ultralytics 8.1.31 NCNN and CLIP updates (#9235)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
41c2d8d99f
commit
3c179f87cb
8 changed files with 77 additions and 60 deletions
|
|
@ -30,8 +30,8 @@ from ultralytics.utils import (
|
|||
)
|
||||
|
||||
# Define valid tasks and modes
|
||||
MODES = "train", "val", "predict", "export", "track", "benchmark"
|
||||
TASKS = "detect", "segment", "classify", "pose", "obb"
|
||||
MODES = {"train", "val", "predict", "export", "track", "benchmark"}
|
||||
TASKS = {"detect", "segment", "classify", "pose", "obb"}
|
||||
TASK2DATA = {
|
||||
"detect": "coco8.yaml",
|
||||
"segment": "coco8-seg.yaml",
|
||||
|
|
@ -93,8 +93,8 @@ CLI_HELP_MSG = f"""
|
|||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
|
||||
CFG_FRACTION_KEYS = (
|
||||
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
|
||||
CFG_FRACTION_KEYS = {
|
||||
"dropout",
|
||||
"iou",
|
||||
"lr0",
|
||||
|
|
@ -118,8 +118,8 @@ CFG_FRACTION_KEYS = (
|
|||
"conf",
|
||||
"iou",
|
||||
"fraction",
|
||||
) # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = (
|
||||
} # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = {
|
||||
"epochs",
|
||||
"patience",
|
||||
"batch",
|
||||
|
|
@ -133,8 +133,8 @@ CFG_INT_KEYS = (
|
|||
"workspace",
|
||||
"nbs",
|
||||
"save_period",
|
||||
)
|
||||
CFG_BOOL_KEYS = (
|
||||
}
|
||||
CFG_BOOL_KEYS = {
|
||||
"save",
|
||||
"exist_ok",
|
||||
"verbose",
|
||||
|
|
@ -169,7 +169,7 @@ CFG_BOOL_KEYS = (
|
|||
"nms",
|
||||
"profile",
|
||||
"multi_scale",
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def cfg2dict(cfg):
|
||||
|
|
@ -219,33 +219,46 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
|
||||
|
||||
# Type and Value checks
|
||||
check_cfg(cfg)
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
|
||||
|
||||
def check_cfg(cfg, hard=True):
|
||||
"""Check Ultralytics configuration argument types and values."""
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
cfg[k] = float(v)
|
||||
elif k in CFG_FRACTION_KEYS:
|
||||
if not isinstance(v, (int, float)):
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
|
||||
)
|
||||
cfg[k] = float(v)
|
||||
if not (0.0 <= v <= 1.0):
|
||||
raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
|
||||
elif k in CFG_INT_KEYS and not isinstance(v, int):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
||||
)
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
|
||||
)
|
||||
cfg[k] = int(v)
|
||||
elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
|
||||
)
|
||||
|
||||
# Return instance
|
||||
return IterableSimpleNamespace(**cfg)
|
||||
if hard:
|
||||
raise TypeError(
|
||||
f"'{k}={v}' is of invalid type {type(v).__name__}. "
|
||||
f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
|
||||
)
|
||||
cfg[k] = bool(v)
|
||||
|
||||
|
||||
def get_save_dir(args, name=None):
|
||||
|
|
@ -464,10 +477,10 @@ def entrypoint(debug=""):
|
|||
overrides = {} # basic overrides, i.e. imgsz=320
|
||||
for a in merge_equals_args(args): # merge spaces around '=' sign
|
||||
if a.startswith("--"):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
|
||||
a = a[2:]
|
||||
if a.endswith(","):
|
||||
LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
|
||||
a = a[:-1]
|
||||
if "=" in a:
|
||||
try:
|
||||
|
|
@ -504,7 +517,7 @@ def entrypoint(debug=""):
|
|||
mode = overrides.get("mode")
|
||||
if mode is None:
|
||||
mode = DEFAULT_CFG.mode or "predict"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
|
||||
elif mode not in MODES:
|
||||
raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
|
||||
|
||||
|
|
@ -520,7 +533,7 @@ def entrypoint(debug=""):
|
|||
model = overrides.pop("model", DEFAULT_CFG.model)
|
||||
if model is None:
|
||||
model = "yolov8n.pt"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")
|
||||
overrides["model"] = model
|
||||
stem = Path(model).stem.lower()
|
||||
if "rtdetr" in stem: # guess architecture
|
||||
|
|
@ -554,15 +567,15 @@ def entrypoint(debug=""):
|
|||
# Mode
|
||||
if mode in ("predict", "track") and "source" not in overrides:
|
||||
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ("train", "val"):
|
||||
if "data" not in overrides and "resume" not in overrides:
|
||||
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
|
||||
elif mode == "export":
|
||||
if "format" not in overrides:
|
||||
overrides["format"] = DEFAULT_CFG.format or "torchscript"
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
|
||||
LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")
|
||||
|
||||
# Run command in python
|
||||
getattr(model, mode)(**overrides) # default args from model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue