diff --git a/docs/en/reference/engine/exporter.md b/docs/en/reference/engine/exporter.md
index 98e81a8a..a0d1822d 100644
--- a/docs/en/reference/engine/exporter.md
+++ b/docs/en/reference/engine/exporter.md
@@ -23,6 +23,10 @@ keywords: YOLOv8, export formats, ONNX, TensorRT, CoreML, machine learning model
+## ::: ultralytics.engine.exporter.validate_args
+
+
+
## ::: ultralytics.engine.exporter.gd_outputs
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 177afda2..824a1aa7 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.3.52"
+__version__ = "8.3.53"
import os
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index ae84cab9..ea2bc01e 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -101,23 +101,47 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d
def export_formats():
"""Ultralytics YOLO export formats."""
x = [
- ["PyTorch", "-", ".pt", True, True],
- ["TorchScript", "torchscript", ".torchscript", True, True],
- ["ONNX", "onnx", ".onnx", True, True],
- ["OpenVINO", "openvino", "_openvino_model", True, False],
- ["TensorRT", "engine", ".engine", False, True],
- ["CoreML", "coreml", ".mlpackage", True, False],
- ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True],
- ["TensorFlow GraphDef", "pb", ".pb", True, True],
- ["TensorFlow Lite", "tflite", ".tflite", True, False],
- ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
- ["TensorFlow.js", "tfjs", "_web_model", True, False],
- ["PaddlePaddle", "paddle", "_paddle_model", True, True],
- ["MNN", "mnn", ".mnn", True, True],
- ["NCNN", "ncnn", "_ncnn_model", True, True],
- ["IMX", "imx", "_imx_model", True, True],
+ ["PyTorch", "-", ".pt", True, True, []],
+ ["TorchScript", "torchscript", ".torchscript", True, True, ["optimize", "batch"]],
+ ["ONNX", "onnx", ".onnx", True, True, ["half", "dynamic", "simplify", "opset", "batch"]],
+ ["OpenVINO", "openvino", "_openvino_model", True, False, ["half", "int8", "batch"]],
+ ["TensorRT", "engine", ".engine", False, True, ["half", "dynamic", "simplify", "int8", "batch"]],
+ ["CoreML", "coreml", ".mlpackage", True, False, ["half", "int8", "nms", "batch"]],
+ ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["keras", "int8", "batch"]],
+ ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
+ ["TensorFlow Lite", "tflite", ".tflite", True, False, ["half", "int8", "batch"]],
+ ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
+ ["TensorFlow.js", "tfjs", "_web_model", True, False, ["half", "int8", "batch"]],
+ ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
+ ["MNN", "mnn", ".mnn", True, True, ["batch", "int8", "half"]],
+ ["NCNN", "ncnn", "_ncnn_model", True, True, ["half", "batch"]],
+ ["IMX", "imx", "_imx_model", True, True, ["int8"]],
]
- return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
+ return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x)))
+
+
+def validate_args(format, passed_args, valid_args):
+ """
+ Validates arguments based on format.
+
+ Args:
+ format (str): The export format.
+ passed_args (Namespace): The arguments used during export.
+ valid_args (dict): List of valid arguments for the format.
+
+ Raises:
+ AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed.
+ """
+ # Only check valid usage of these args
+ export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"]
+
+ assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed."
+ custom = {"batch": 1, "data": None, "device": None} # exporter defaults
+ default_args = get_cfg(DEFAULT_CFG, custom)
+ for arg in export_args:
+ not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None)
+ if not_default:
+ assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'"
def gd_outputs(gd):
@@ -182,7 +206,8 @@ class Exporter:
fmt = "engine"
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
fmt = "coreml"
- fmts = tuple(export_formats()["Argument"][1:]) # available export formats
+ fmts_dict = export_formats()
+ fmts = tuple(fmts_dict["Argument"][1:]) # available export formats
if fmt not in fmts:
import difflib
@@ -224,7 +249,8 @@ class Exporter:
assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
self.device = select_device("cpu" if self.args.device is None else self.args.device)
- # Checks
+ # Argument compatibility checks
+ validate_args(fmt, self.args, fmts_dict["Arguments"][flags.index(True) + 1])
if imx and not self.args.int8:
LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.")
self.args.int8 = True
diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py
index e65d1288..e5a6c22a 100644
--- a/ultralytics/utils/benchmarks.py
+++ b/ultralytics/utils/benchmarks.py
@@ -90,7 +90,7 @@ def benchmark(
y = []
t0 = time.time()
- for i, (name, format, suffix, cpu, gpu) in enumerate(zip(*export_formats().values())):
+ for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())):
emoji, filename = "❌", None # export defaults
try:
# Checks