diff --git a/docs/en/macros/export-table.md b/docs/en/macros/export-table.md
index a8e8e260..8bf018e5 100644
--- a/docs/en/macros/export-table.md
+++ b/docs/en/macros/export-table.md
@@ -1,18 +1,18 @@
-| Format | `format` Argument | Model | Metadata | Arguments |
-| ------------------------------------------------- | ----------------- | ----------------------------------------------- | -------- | -------------------------------------------------------------------- |
-| [PyTorch](https://pytorch.org/) | - | `{{ model_name or "yolo11n" }}.pt` | ✅ | - |
-| [TorchScript](../integrations/torchscript.md) | `torchscript` | `{{ model_name or "yolo11n" }}.torchscript` | ✅ | `imgsz`, `optimize`, `batch` |
-| [ONNX](../integrations/onnx.md) | `onnx` | `{{ model_name or "yolo11n" }}.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset`, `batch` |
-| [OpenVINO](../integrations/openvino.md) | `openvino` | `{{ model_name or "yolo11n" }}_openvino_model/` | ✅ | `imgsz`, `half`, `dynamic`, `int8`, `batch` |
-| [TensorRT](../integrations/tensorrt.md) | `engine` | `{{ model_name or "yolo11n" }}.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace`, `int8`, `batch` |
-| [CoreML](../integrations/coreml.md) | `coreml` | `{{ model_name or "yolo11n" }}.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
-| [TF SavedModel](../integrations/tf-savedmodel.md) | `saved_model` | `{{ model_name or "yolo11n" }}_saved_model/` | ✅ | `imgsz`, `keras`, `int8`, `batch` |
-| [TF GraphDef](../integrations/tf-graphdef.md) | `pb` | `{{ model_name or "yolo11n" }}.pb` | ❌ | `imgsz`, `batch` |
-| [TF Lite](../integrations/tflite.md) | `tflite` | `{{ model_name or "yolo11n" }}.tflite` | ✅ | `imgsz`, `half`, `int8`, `batch` |
-| [TF Edge TPU](../integrations/edge-tpu.md) | `edgetpu` | `{{ model_name or "yolo11n" }}_edgetpu.tflite` | ✅ | `imgsz` |
-| [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `batch` |
-| [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
-| [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
-| [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |
-| [IMX500](../integrations/sony-imx500.md) | `imx` | `{{ model_name or "yolov8n" }}_imx_model/` | ✅ | `imgsz`, `int8` |
-| [RKNN](../integrations/rockchip-rknn.md) | `rknn` | `{{ model_name or "yolo11n" }}_rknn_model/` | ✅ | `imgsz`, `batch`, `name` |
+| Format | `format` Argument | Model | Metadata | Arguments |
+| ------------------------------------------------- | ----------------- | ----------------------------------------------- | -------- | --------------------------------------------------------------------------- |
+| [PyTorch](https://pytorch.org/) | - | `{{ model_name or "yolo11n" }}.pt` | ✅ | - |
+| [TorchScript](../integrations/torchscript.md) | `torchscript` | `{{ model_name or "yolo11n" }}.torchscript` | ✅ | `imgsz`, `optimize`, `nms`, `batch` |
+| [ONNX](../integrations/onnx.md) | `onnx` | `{{ model_name or "yolo11n" }}.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset`, `nms`, `batch` |
+| [OpenVINO](../integrations/openvino.md) | `openvino` | `{{ model_name or "yolo11n" }}_openvino_model/` | ✅ | `imgsz`, `half`, `dynamic`, `int8`, `nms`, `batch` |
+| [TensorRT](../integrations/tensorrt.md) | `engine` | `{{ model_name or "yolo11n" }}.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace`, `int8`, `nms`, `batch` |
+| [CoreML](../integrations/coreml.md) | `coreml` | `{{ model_name or "yolo11n" }}.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
+| [TF SavedModel](../integrations/tf-savedmodel.md) | `saved_model` | `{{ model_name or "yolo11n" }}_saved_model/` | ✅ | `imgsz`, `keras`, `int8`, `nms`, `batch` |
+| [TF GraphDef](../integrations/tf-graphdef.md) | `pb` | `{{ model_name or "yolo11n" }}.pb` | ❌ | `imgsz`, `batch` |
+| [TF Lite](../integrations/tflite.md) | `tflite` | `{{ model_name or "yolo11n" }}.tflite` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
+| [TF Edge TPU](../integrations/edge-tpu.md) | `edgetpu` | `{{ model_name or "yolo11n" }}_edgetpu.tflite` | ✅ | `imgsz` |
+| [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
+| [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
+| [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
+| [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |
+| [IMX500](../integrations/sony-imx500.md) | `imx` | `{{ model_name or "yolov8n" }}_imx_model/` | ✅ | `imgsz`, `int8` |
+| [RKNN](../integrations/rockchip-rknn.md) | `rknn` | `{{ model_name or "yolo11n" }}_rknn_model/` | ✅ | `imgsz`, `batch`, `name` |
diff --git a/docs/en/reference/engine/exporter.md b/docs/en/reference/engine/exporter.md
index a0d1822d..a650b314 100644
--- a/docs/en/reference/engine/exporter.md
+++ b/docs/en/reference/engine/exporter.md
@@ -19,6 +19,10 @@ keywords: YOLOv8, export formats, ONNX, TensorRT, CoreML, machine learning model
+## ::: ultralytics.engine.exporter.NMSModel
+
+
+
## ::: ultralytics.engine.exporter.export_formats
diff --git a/tests/test_exports.py b/tests/test_exports.py
index 970065d6..c364b481 100644
--- a/tests/test_exports.py
+++ b/tests/test_exports.py
@@ -43,23 +43,19 @@ def test_export_openvino():
@pytest.mark.slow
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
@pytest.mark.parametrize(
- "task, dynamic, int8, half, batch",
+ "task, dynamic, int8, half, batch, nms",
[ # generate all combinations but exclude those where both int8 and half are True
- (task, dynamic, int8, half, batch)
- for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
+ (task, dynamic, int8, half, batch, nms)
+ for task, dynamic, int8, half, batch, nms in product(
+ TASKS, [True, False], [True, False], [True, False], [1, 2], [True, False]
+ )
if not (int8 and half) # exclude cases where both int8 and half are True
],
)
-def test_export_openvino_matrix(task, dynamic, int8, half, batch):
+def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
"""Test YOLO model exports to OpenVINO under various configuration matrix conditions."""
file = YOLO(TASK2MODEL[task]).export(
- format="openvino",
- imgsz=32,
- dynamic=dynamic,
- int8=int8,
- half=half,
- batch=batch,
- data=TASK2DATA[task],
+ format="openvino", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, data=TASK2DATA[task], nms=nms
)
if WINDOWS:
# Use unique filenames due to Windows file permissions bug possibly due to latent threaded use
@@ -72,34 +68,26 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):
@pytest.mark.slow
@pytest.mark.parametrize(
- "task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False])
+ "task, dynamic, int8, half, batch, simplify, nms",
+ product(TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]),
)
-def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
+def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
"""Test YOLO exports to ONNX format with various configurations and parameters."""
file = YOLO(TASK2MODEL[task]).export(
- format="onnx",
- imgsz=32,
- dynamic=dynamic,
- int8=int8,
- half=half,
- batch=batch,
- simplify=simplify,
+ format="onnx", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, simplify=simplify, nms=nms
)
YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference
Path(file).unlink() # cleanup
@pytest.mark.slow
-@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
-def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
+@pytest.mark.parametrize(
+ "task, dynamic, int8, half, batch, nms", product(TASKS, [False], [False], [False], [1, 2], [True, False])
+)
+def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):
"""Tests YOLO model exports to TorchScript format under varied configurations."""
file = YOLO(TASK2MODEL[task]).export(
- format="torchscript",
- imgsz=32,
- dynamic=dynamic,
- int8=int8,
- half=half,
- batch=batch,
+ format="torchscript", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
)
YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32) # exported model inference at batch=3
Path(file).unlink() # cleanup
@@ -135,22 +123,19 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
@pytest.mark.parametrize(
- "task, dynamic, int8, half, batch",
+ "task, dynamic, int8, half, batch, nms",
[ # generate all combinations but exclude those where both int8 and half are True
- (task, dynamic, int8, half, batch)
- for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
+ (task, dynamic, int8, half, batch, nms)
+ for task, dynamic, int8, half, batch, nms in product(
+ TASKS, [False], [True, False], [True, False], [1], [True, False]
+ )
if not (int8 and half) # exclude cases where both int8 and half are True
],
)
-def test_export_tflite_matrix(task, dynamic, int8, half, batch):
+def test_export_tflite_matrix(task, dynamic, int8, half, batch, nms):
"""Test YOLO exports to TFLite format considering various export configurations."""
file = YOLO(TASK2MODEL[task]).export(
- format="tflite",
- imgsz=32,
- dynamic=dynamic,
- int8=int8,
- half=half,
- batch=batch,
+ format="tflite", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
)
YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3
Path(file).unlink() # cleanup
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 98c1509d..e7be7b76 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-__version__ = "8.3.66"
+__version__ = "8.3.67"
import os
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index ad2e6b07..88e6e533 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
)
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
from ultralytics.utils.files import file_size, spaces_in_path
-from ultralytics.utils.ops import Profile
+from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
@@ -111,16 +111,16 @@ def export_formats():
"""Ultralytics YOLO export formats."""
x = [
["PyTorch", "-", ".pt", True, True, []],
- ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
- ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
- ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
- ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
+ ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
+ ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
+ ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
+ ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
- ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
+ ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
- ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
+ ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
- ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
+ ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
@@ -281,6 +281,11 @@ class Exporter:
)
if self.args.int8 and tflite:
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
+ if self.args.nms:
+ if getattr(model, "end2end", False):
+ LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
+ self.args.nms = False
+ self.args.conf = self.args.conf or 0.25 # set conf default value for nms export
if edgetpu:
if not LINUX:
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
@@ -344,8 +349,8 @@ class Exporter:
)
y = None
- for _ in range(2):
- y = model(im) # dry runs
+ for _ in range(2): # dry runs
+ y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
if self.args.half and onnx and self.device.type != "cpu":
im, model = im.half(), model.half() # to FP16
@@ -476,7 +481,7 @@ class Exporter:
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
f = self.file.with_suffix(".torchscript")
- ts = torch.jit.trace(self.model, self.im, strict=False)
+ ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
LOGGER.info(f"{prefix} optimizing for mobile...")
@@ -499,7 +504,6 @@ class Exporter:
opset_version = self.args.opset or get_latest_opset()
LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
f = str(self.file.with_suffix(".onnx"))
-
output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
dynamic = self.args.dynamic
if dynamic:
@@ -509,9 +513,18 @@ class Exporter:
dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
elif isinstance(self.model, DetectionModel):
dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
+ if self.args.nms: # only batch size is dynamic with NMS
+ dynamic["output0"].pop(2)
+ if self.args.nms and self.model.task == "obb":
+ self.args.opset = opset_version # for NMSModel
+ # OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
+ torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
+ check_requirements("onnxslim>=0.1.46") # Older versions has bug with OBB
torch.onnx.export(
- self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
+ NMSModel(self.model.cpu() if dynamic else self.model, self.args)
+ if self.args.nms
+ else self.model, # dynamic=True only compatible with cpu
self.im.cpu() if dynamic else self.im,
f,
verbose=False,
@@ -553,7 +566,7 @@ class Exporter:
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
ov_model = ov.convert_model(
- self.model,
+ NMSModel(self.model, self.args) if self.args.nms else self.model,
input=None if self.args.dynamic else [self.im.shape],
example_input=self.im,
)
@@ -736,9 +749,6 @@ class Exporter:
f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
if f.is_dir():
shutil.rmtree(f)
- if self.args.nms and getattr(self.model, "end2end", False):
- LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
- self.args.nms = False
bias = [0.0, 0.0, 0.0]
scale = 1 / 255
@@ -1438,8 +1448,8 @@ class Exporter:
nms.coordinatesOutputFeatureName = "coordinates"
nms.iouThresholdInputFeatureName = "iouThreshold"
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
- nms.iouThreshold = 0.45
- nms.confidenceThreshold = 0.25
+ nms.iouThreshold = self.args.iou
+ nms.confidenceThreshold = self.args.conf
nms.pickTop.perClass = True
nms.stringClassLabels.vector.extend(names.values())
nms_model = ct.models.MLModel(nms_spec)
@@ -1507,3 +1517,91 @@ class IOSDetectModel(torch.nn.Module):
"""Normalize predictions of object detection model with input size-dependent factors."""
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
+
+
+class NMSModel(torch.nn.Module):
+ """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
+
+ def __init__(self, model, args):
+ """
+ Initialize the NMSModel.
+
+ Args:
+ model (torch.nn.module): The model to wrap with NMS postprocessing.
+ args (Namespace): The export arguments.
+ """
+ super().__init__()
+ self.model = model
+ self.args = args
+ self.obb = model.task == "obb"
+ self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
+
+ def forward(self, x):
+ """
+ Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
+
+ Args:
+ x (torch.tensor): The preprocessed tensor with shape (N, 3, H, W).
+
+ Returns:
+ out (torch.tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
+ """
+ from functools import partial
+
+ from torchvision.ops import nms
+
+ preds = self.model(x)
+ pred = preds[0] if isinstance(preds, tuple) else preds
+ pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
+ extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose
+ boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
+ scores, classes = scores.max(dim=-1)
+ # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
+ out = torch.zeros(
+ boxes.shape[0],
+ self.args.max_det,
+ boxes.shape[-1] + 2 + extra_shape,
+ device=boxes.device,
+ dtype=boxes.dtype,
+ )
+ for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
+ mask = score > self.args.conf
+ if self.is_tf:
+ # TFLite GatherND error if mask is empty
+ score *= mask
+ # Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
+ mask = score.topk(self.args.max_det * 5).indices
+ box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
+ if not self.obb:
+ box = xywh2xyxy(box)
+ if self.is_tf:
+ # TFlite bug returns less boxes
+ box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
+ nmsbox = box.clone()
+ # `8` is the minimum value experimented to get correct NMS results for obb
+ multiplier = 8 if self.obb else 1
+ # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
+ if self.args.format == "tflite": # TFLite is already normalized
+ nmsbox *= multiplier
+ else:
+ nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
+ if not self.args.agnostic_nms: # class-specific NMS
+ end = 2 if self.obb else 4
+ # fully explicit expansion otherwise reshape error
+ # large max_wh causes issues when quantizing
+ cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
+ offbox = nmsbox[:, :end] + cls_offset * multiplier
+ nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
+ nms_fn = (
+ partial(nms_rotated, use_triu=not (self.is_tf or (self.args.opset or 14) < 14)) if self.obb else nms
+ )
+ keep = nms_fn(
+ torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
+ score,
+ self.args.iou,
+ )[: self.args.max_det]
+ dets = torch.cat([box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1), extra[keep]], dim=-1)
+ # Zero-pad to max_det size to avoid reshape error
+ pad = (0, 0, 0, self.args.max_det - dets.shape[0])
+ out[i] = torch.nn.functional.pad(dets, pad)
+ return (out, preds[1]) if self.model.task == "segment" else out
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index 9fc9e6e1..b7f7dd72 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -305,7 +305,7 @@ class Results(SimpleClass):
if v is not None:
return len(v)
- def update(self, boxes=None, masks=None, probs=None, obb=None):
+ def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None):
"""
Updates the Results object with new detection data.
@@ -318,6 +318,7 @@ class Results(SimpleClass):
masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
+ keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.
Examples:
>>> results = model("image.jpg")
@@ -332,6 +333,8 @@ class Results(SimpleClass):
self.probs = probs
if obb is not None:
self.obb = OBB(obb, self.orig_shape)
+ if keypoints is not None:
+ self.keypoints = Keypoints(keypoints, self.orig_shape)
def _apply(self, fn, *args, **kwargs):
"""
diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py
index c3d0f37e..ca01e94e 100644
--- a/ultralytics/models/nas/val.py
+++ b/ultralytics/models/nas/val.py
@@ -38,13 +38,7 @@ class NASValidator(DetectionValidator):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
- return ops.non_max_suppression(
+ return super().postprocess(
preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=False,
- agnostic=self.args.single_cls or self.args.agnostic_nms,
- max_det=self.args.max_det,
max_time_img=0.5,
)
diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py
index 4d9da896..172e54d3 100644
--- a/ultralytics/models/yolo/detect/predict.py
+++ b/ultralytics/models/yolo/detect/predict.py
@@ -20,22 +20,54 @@ class DetectionPredictor(BasePredictor):
```
"""
- def postprocess(self, preds, img, orig_imgs):
+ def postprocess(self, preds, img, orig_imgs, **kwargs):
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
- agnostic=self.args.agnostic_nms,
+ self.args.classes,
+ self.args.agnostic_nms,
max_det=self.args.max_det,
- classes=self.args.classes,
+ nc=len(self.model.names),
+ end2end=getattr(self.model, "end2end", False),
+ rotated=self.args.task == "obb",
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
- results = []
- for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
- pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
- results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
- return results
+ return self.construct_results(preds, img, orig_imgs, **kwargs)
+
+ def construct_results(self, preds, img, orig_imgs):
+ """
+ Constructs a list of result objects from the predictions.
+
+ Args:
+ preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
+ img (torch.Tensor): The image after preprocessing.
+ orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+
+ Returns:
+ (list): List of result objects containing the original images, image paths, class names, and bounding boxes.
+ """
+ return [
+ self.construct_result(pred, img, orig_img, img_path)
+ for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
+ ]
+
+ def construct_result(self, pred, img, orig_img, img_path):
+ """
+ Constructs the result object from the prediction.
+
+ Args:
+ pred (torch.Tensor): The predicted bounding boxes and scores.
+ img (torch.Tensor): The image after preprocessing.
+ orig_img (np.ndarray): The original image before preprocessing.
+ img_path (str): The path to the original image.
+
+ Returns:
+ (Results): The result object containing the original image, image path, class names, and bounding boxes.
+ """
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+ return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py
index d5fcbfe5..ec809d55 100644
--- a/ultralytics/models/yolo/detect/val.py
+++ b/ultralytics/models/yolo/detect/val.py
@@ -78,6 +78,7 @@ class DetectionValidator(BaseValidator):
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
self.names = model.names
self.nc = len(model.names)
+ self.end2end = getattr(model, "end2end", False)
self.metrics.names = self.names
self.metrics.plot = self.args.plots
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
@@ -96,9 +97,12 @@ class DetectionValidator(BaseValidator):
self.args.conf,
self.args.iou,
labels=self.lb,
+ nc=self.nc,
multi_label=True,
agnostic=self.args.single_cls or self.args.agnostic_nms,
max_det=self.args.max_det,
+ end2end=self.end2end,
+ rotated=self.args.task == "obb",
)
def _prepare_batch(self, si, batch):
diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py
index ebbd7530..ef6214d4 100644
--- a/ultralytics/models/yolo/obb/predict.py
+++ b/ultralytics/models/yolo/obb/predict.py
@@ -27,27 +27,20 @@ class OBBPredictor(DetectionPredictor):
super().__init__(cfg, overrides, _callbacks)
self.args.task = "obb"
- def postprocess(self, preds, img, orig_imgs):
- """Post-processes predictions and returns a list of Results objects."""
- preds = ops.non_max_suppression(
- preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=len(self.model.names),
- classes=self.args.classes,
- rotated=True,
- )
+ def construct_result(self, pred, img, orig_img, img_path):
+ """
+ Constructs the result object from the prediction.
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
- orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+ Args:
+ pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
+ img (torch.Tensor): The image after preprocessing.
+ orig_img (np.ndarray): The original image before preprocessing.
+ img_path (str): The path to the original image.
- results = []
- for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
- rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
- rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
- # xywh, r, conf, cls
- obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
- results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
- return results
+ Returns:
+ (Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
+ """
+ rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
+ rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
+ obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
+ return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py
index a8392858..b5cb89f1 100644
--- a/ultralytics/models/yolo/obb/val.py
+++ b/ultralytics/models/yolo/obb/val.py
@@ -36,20 +36,6 @@ class OBBValidator(DetectionValidator):
val = self.data.get(self.args.split, "") # validation path
self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO
- def postprocess(self, preds):
- """Apply Non-maximum suppression to prediction outputs."""
- return ops.non_max_suppression(
- preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- nc=self.nc,
- multi_label=True,
- agnostic=self.args.single_cls or self.args.agnostic_nms,
- max_det=self.args.max_det,
- rotated=True,
- )
-
def _process_batch(self, detections, gt_bboxes, gt_cls):
"""
Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py
index 4d9315b8..75334b75 100644
--- a/ultralytics/models/yolo/pose/predict.py
+++ b/ultralytics/models/yolo/pose/predict.py
@@ -1,6 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
@@ -30,27 +29,21 @@ class PosePredictor(DetectionPredictor):
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
- def postprocess(self, preds, img, orig_imgs):
- """Return detection results for a given input image or list of images."""
- preds = ops.non_max_suppression(
- preds,
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- classes=self.args.classes,
- nc=len(self.model.names),
- )
+ def construct_result(self, pred, img, orig_img, img_path):
+ """
+ Constructs the result object from the prediction.
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
- orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+ Args:
+ pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
+ img (torch.Tensor): The image after preprocessing.
+ orig_img (np.ndarray): The original image before preprocessing.
+ img_path (str): The path to the original image.
- results = []
- for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
- pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
- pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
- pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
- results.append(
- Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
- )
- return results
+ Returns:
+ (Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
+ """
+ result = super().construct_result(pred, img, orig_img, img_path)
+ pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
+ pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
+ result.update(keypoints=pred_kpts)
+ return result
diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py
index 2acdaa3e..9fc872f9 100644
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -61,19 +61,6 @@ class PoseValidator(DetectionValidator):
"mAP50-95)",
)
- def postprocess(self, preds):
- """Apply non-maximum suppression and return detections with high confidence scores."""
- return ops.non_max_suppression(
- preds,
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=True,
- agnostic=self.args.single_cls or self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=self.nc,
- )
-
def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model)
diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py
index 78281625..4e0adc7c 100644
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -27,29 +27,48 @@ class SegmentationPredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_imgs):
"""Applies non-max suppression and processes detections for each image in an input batch."""
- p = ops.non_max_suppression(
- preds[0],
- self.args.conf,
- self.args.iou,
- agnostic=self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=len(self.model.names),
- classes=self.args.classes,
- )
+ # tuple if PyTorch model or array if exported
+ protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
+ return super().postprocess(preds[0], img, orig_imgs, protos=protos)
- if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
- orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
+ def construct_results(self, preds, img, orig_imgs, protos):
+ """
+ Constructs a list of result objects from the predictions.
- results = []
- proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported
- for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
- if not len(pred): # save empty boxes
- masks = None
- elif self.args.retina_masks:
- pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
- masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
- else:
- masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
- pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
- results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
- return results
+ Args:
+ preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
+ img (torch.Tensor): The image after preprocessing.
+ orig_imgs (List[np.ndarray]): List of original images before preprocessing.
+ protos (List[torch.Tensor]): List of prototype masks.
+
+ Returns:
+ (list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
+ """
+ return [
+ self.construct_result(pred, img, orig_img, img_path, proto)
+ for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
+ ]
+
+ def construct_result(self, pred, img, orig_img, img_path, proto):
+ """
+ Constructs the result object from the prediction.
+
+ Args:
+ pred (np.ndarray): The predicted bounding boxes, scores, and masks.
+ img (torch.Tensor): The image after preprocessing.
+ orig_img (np.ndarray): The original image before preprocessing.
+ img_path (str): The path to the original image.
+ proto (torch.Tensor): The prototype masks.
+
+ Returns:
+ (Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
+ """
+ if not len(pred): # save empty boxes
+ masks = None
+ elif self.args.retina_masks:
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+ masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
+ else:
+ masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
+ pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
+ return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 81f3d01b..8be870e3 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -70,16 +70,7 @@ class SegmentationValidator(DetectionValidator):
def postprocess(self, preds):
"""Post-processes YOLO predictions and returns output detections with proto."""
- p = ops.non_max_suppression(
- preds[0],
- self.args.conf,
- self.args.iou,
- labels=self.lb,
- multi_label=True,
- agnostic=self.args.single_cls or self.args.agnostic_nms,
- max_det=self.args.max_det,
- nc=self.nc,
- )
+ p = super().postprocess(preds[0])
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
return p, proto
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index c63f999d..f536a602 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -132,6 +132,7 @@ class AutoBackend(nn.Module):
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH)
stride = 32 # default stride
+ end2end = False # default end2end
model, metadata, task = None, None, None
# Set device
@@ -222,16 +223,18 @@ class AutoBackend(nn.Module):
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
+ fp16 = True if "float16" in session.get_inputs()[0].type else False
if not dynamic:
io = session.io_binding()
bindings = []
for output in session.get_outputs():
- y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
+ out_fp16 = "float16" in output.type
+ y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
io.bind_output(
name=output.name,
device_type=device.type,
device_id=device.index if cuda else 0,
- element_type=np.float16 if fp16 else np.float32,
+ element_type=np.float16 if out_fp16 else np.float32,
shape=tuple(y_tensor.shape),
buffer_ptr=y_tensor.data_ptr(),
)
@@ -501,7 +504,7 @@ class AutoBackend(nn.Module):
for k, v in metadata.items():
if k in {"stride", "batch"}:
metadata[k] = int(v)
- elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
+ elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
metadata[k] = eval(v)
stride = metadata["stride"]
task = metadata["task"]
@@ -509,6 +512,7 @@ class AutoBackend(nn.Module):
imgsz = metadata["imgsz"]
names = metadata["names"]
kpt_shape = metadata.get("kpt_shape")
+ end2end = metadata.get("args", {}).get("nms", False)
elif not (pt or triton or nn_module):
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
@@ -703,9 +707,12 @@ class AutoBackend(nn.Module):
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
- if x.shape[-1] == 6: # end-to-end model
+ if x.shape[-1] == 6 or self.end2end: # end-to-end model
x[:, :, [0, 2]] *= w
x[:, :, [1, 3]] *= h
+ if self.task == "pose":
+ x[:, :, 6::3] *= w
+ x[:, :, 7::3] *= h
else:
x[:, [0, 2]] *= w
x[:, [1, 3]] *= h
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 52b51552..af41ffee 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -143,7 +143,7 @@ def make_divisible(x, divisor):
return math.ceil(x / divisor) * divisor
-def nms_rotated(boxes, scores, threshold=0.45):
+def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
"""
NMS for oriented bounding boxes using probiou and fast-nms.
@@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
scores (torch.Tensor): Confidence scores, shape (N,).
threshold (float, optional): IoU threshold. Defaults to 0.45.
+ use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
+ when exporting obb models to some formats that do not support `torch.triu`.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
"""
- if len(boxes) == 0:
- return np.empty((0,), dtype=np.int8)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
- ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
- pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+ ious = batch_probiou(boxes, boxes)
+ if use_triu:
+ ious = ious.triu_(diagonal=1)
+ # pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
+ # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
+ pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
+ else:
+ n = boxes.shape[0]
+ row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
+ col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
+ upper_mask = row_idx < col_idx
+ ious = ious * upper_mask
+ # Zeroing these scores ensures the additional indices would not affect the final results
+ scores[~((ious >= threshold).sum(0) <= 0)] = 0
+ # NOTE: return indices with fixed length to avoid TFLite reshape error
+ pick = torch.topk(scores, scores.shape[0]).indices
return sorted_idx[pick]
@@ -179,6 +193,7 @@ def non_max_suppression(
max_wh=7680,
in_place=True,
rotated=False,
+ end2end=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@@ -205,6 +220,7 @@ def non_max_suppression(
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
+ end2end (bool): If the model doesn't require NMS.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@@ -221,7 +237,7 @@ def non_max_suppression(
if classes is not None:
classes = torch.tensor(classes, device=prediction.device)
- if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
+ if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
if classes is not None:
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]