ultralytics 8.3.67 NMS Export for Detect, Segment, Pose and OBB YOLO models (#18484)
Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
0e48a00303
commit
9181ff62f5
17 changed files with 320 additions and 208 deletions
|
|
@ -1,16 +1,16 @@
|
|||
| Format | `format` Argument | Model | Metadata | Arguments |
|
||||
| ------------------------------------------------- | ----------------- | ----------------------------------------------- | -------- | -------------------------------------------------------------------- |
|
||||
| ------------------------------------------------- | ----------------- | ----------------------------------------------- | -------- | --------------------------------------------------------------------------- |
|
||||
| [PyTorch](https://pytorch.org/) | - | `{{ model_name or "yolo11n" }}.pt` | ✅ | - |
|
||||
| [TorchScript](../integrations/torchscript.md) | `torchscript` | `{{ model_name or "yolo11n" }}.torchscript` | ✅ | `imgsz`, `optimize`, `batch` |
|
||||
| [ONNX](../integrations/onnx.md) | `onnx` | `{{ model_name or "yolo11n" }}.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset`, `batch` |
|
||||
| [OpenVINO](../integrations/openvino.md) | `openvino` | `{{ model_name or "yolo11n" }}_openvino_model/` | ✅ | `imgsz`, `half`, `dynamic`, `int8`, `batch` |
|
||||
| [TensorRT](../integrations/tensorrt.md) | `engine` | `{{ model_name or "yolo11n" }}.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace`, `int8`, `batch` |
|
||||
| [TorchScript](../integrations/torchscript.md) | `torchscript` | `{{ model_name or "yolo11n" }}.torchscript` | ✅ | `imgsz`, `optimize`, `nms`, `batch` |
|
||||
| [ONNX](../integrations/onnx.md) | `onnx` | `{{ model_name or "yolo11n" }}.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset`, `nms`, `batch` |
|
||||
| [OpenVINO](../integrations/openvino.md) | `openvino` | `{{ model_name or "yolo11n" }}_openvino_model/` | ✅ | `imgsz`, `half`, `dynamic`, `int8`, `nms`, `batch` |
|
||||
| [TensorRT](../integrations/tensorrt.md) | `engine` | `{{ model_name or "yolo11n" }}.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace`, `int8`, `nms`, `batch` |
|
||||
| [CoreML](../integrations/coreml.md) | `coreml` | `{{ model_name or "yolo11n" }}.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
|
||||
| [TF SavedModel](../integrations/tf-savedmodel.md) | `saved_model` | `{{ model_name or "yolo11n" }}_saved_model/` | ✅ | `imgsz`, `keras`, `int8`, `batch` |
|
||||
| [TF SavedModel](../integrations/tf-savedmodel.md) | `saved_model` | `{{ model_name or "yolo11n" }}_saved_model/` | ✅ | `imgsz`, `keras`, `int8`, `nms`, `batch` |
|
||||
| [TF GraphDef](../integrations/tf-graphdef.md) | `pb` | `{{ model_name or "yolo11n" }}.pb` | ❌ | `imgsz`, `batch` |
|
||||
| [TF Lite](../integrations/tflite.md) | `tflite` | `{{ model_name or "yolo11n" }}.tflite` | ✅ | `imgsz`, `half`, `int8`, `batch` |
|
||||
| [TF Lite](../integrations/tflite.md) | `tflite` | `{{ model_name or "yolo11n" }}.tflite` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
|
||||
| [TF Edge TPU](../integrations/edge-tpu.md) | `edgetpu` | `{{ model_name or "yolo11n" }}_edgetpu.tflite` | ✅ | `imgsz` |
|
||||
| [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `batch` |
|
||||
| [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `nms`, `batch` |
|
||||
| [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
|
||||
| [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
|
||||
| [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ keywords: YOLOv8, export formats, ONNX, TensorRT, CoreML, machine learning model
|
|||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.engine.exporter.NMSModel
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.engine.exporter.export_formats
|
||||
|
||||
<br><br><hr><br>
|
||||
|
|
|
|||
|
|
@ -43,23 +43,19 @@ def test_export_openvino():
|
|||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13")
|
||||
@pytest.mark.parametrize(
|
||||
"task, dynamic, int8, half, batch",
|
||||
"task, dynamic, int8, half, batch, nms",
|
||||
[ # generate all combinations but exclude those where both int8 and half are True
|
||||
(task, dynamic, int8, half, batch)
|
||||
for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2])
|
||||
(task, dynamic, int8, half, batch, nms)
|
||||
for task, dynamic, int8, half, batch, nms in product(
|
||||
TASKS, [True, False], [True, False], [True, False], [1, 2], [True, False]
|
||||
)
|
||||
if not (int8 and half) # exclude cases where both int8 and half are True
|
||||
],
|
||||
)
|
||||
def test_export_openvino_matrix(task, dynamic, int8, half, batch):
|
||||
def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
|
||||
"""Test YOLO model exports to OpenVINO under various configuration matrix conditions."""
|
||||
file = YOLO(TASK2MODEL[task]).export(
|
||||
format="openvino",
|
||||
imgsz=32,
|
||||
dynamic=dynamic,
|
||||
int8=int8,
|
||||
half=half,
|
||||
batch=batch,
|
||||
data=TASK2DATA[task],
|
||||
format="openvino", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, data=TASK2DATA[task], nms=nms
|
||||
)
|
||||
if WINDOWS:
|
||||
# Use unique filenames due to Windows file permissions bug possibly due to latent threaded use
|
||||
|
|
@ -72,34 +68,26 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch):
|
|||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False])
|
||||
"task, dynamic, int8, half, batch, simplify, nms",
|
||||
product(TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]),
|
||||
)
|
||||
def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify):
|
||||
def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
|
||||
"""Test YOLO exports to ONNX format with various configurations and parameters."""
|
||||
file = YOLO(TASK2MODEL[task]).export(
|
||||
format="onnx",
|
||||
imgsz=32,
|
||||
dynamic=dynamic,
|
||||
int8=int8,
|
||||
half=half,
|
||||
batch=batch,
|
||||
simplify=simplify,
|
||||
format="onnx", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, simplify=simplify, nms=nms
|
||||
)
|
||||
YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference
|
||||
Path(file).unlink() # cleanup
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2]))
|
||||
def test_export_torchscript_matrix(task, dynamic, int8, half, batch):
|
||||
@pytest.mark.parametrize(
|
||||
"task, dynamic, int8, half, batch, nms", product(TASKS, [False], [False], [False], [1, 2], [True, False])
|
||||
)
|
||||
def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):
|
||||
"""Tests YOLO model exports to TorchScript format under varied configurations."""
|
||||
file = YOLO(TASK2MODEL[task]).export(
|
||||
format="torchscript",
|
||||
imgsz=32,
|
||||
dynamic=dynamic,
|
||||
int8=int8,
|
||||
half=half,
|
||||
batch=batch,
|
||||
format="torchscript", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
|
||||
)
|
||||
YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32) # exported model inference at batch=3
|
||||
Path(file).unlink() # cleanup
|
||||
|
|
@ -135,22 +123,19 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
|
|||
@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10")
|
||||
@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS")
|
||||
@pytest.mark.parametrize(
|
||||
"task, dynamic, int8, half, batch",
|
||||
"task, dynamic, int8, half, batch, nms",
|
||||
[ # generate all combinations but exclude those where both int8 and half are True
|
||||
(task, dynamic, int8, half, batch)
|
||||
for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
|
||||
(task, dynamic, int8, half, batch, nms)
|
||||
for task, dynamic, int8, half, batch, nms in product(
|
||||
TASKS, [False], [True, False], [True, False], [1], [True, False]
|
||||
)
|
||||
if not (int8 and half) # exclude cases where both int8 and half are True
|
||||
],
|
||||
)
|
||||
def test_export_tflite_matrix(task, dynamic, int8, half, batch):
|
||||
def test_export_tflite_matrix(task, dynamic, int8, half, batch, nms):
|
||||
"""Test YOLO exports to TFLite format considering various export configurations."""
|
||||
file = YOLO(TASK2MODEL[task]).export(
|
||||
format="tflite",
|
||||
imgsz=32,
|
||||
dynamic=dynamic,
|
||||
int8=int8,
|
||||
half=half,
|
||||
batch=batch,
|
||||
format="tflite", imgsz=32, dynamic=dynamic, int8=int8, half=half, batch=batch, nms=nms
|
||||
)
|
||||
YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3
|
||||
Path(file).unlink() # cleanup
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
__version__ = "8.3.66"
|
||||
__version__ = "8.3.67"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
|
|||
)
|
||||
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
|
||||
from ultralytics.utils.files import file_size, spaces_in_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
|
||||
|
||||
|
||||
|
|
@ -111,16 +111,16 @@ def export_formats():
|
|||
"""Ultralytics YOLO export formats."""
|
||||
x = [
|
||||
["PyTorch", "-", ".pt", True, True, []],
|
||||
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
|
||||
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
|
||||
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
|
||||
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
|
||||
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
|
||||
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
|
||||
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
|
||||
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
|
||||
["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
|
||||
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
|
||||
["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
|
||||
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
|
||||
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
|
||||
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
|
||||
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
|
||||
["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
|
||||
["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
|
||||
|
|
@ -281,6 +281,11 @@ class Exporter:
|
|||
)
|
||||
if self.args.int8 and tflite:
|
||||
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
|
||||
if self.args.nms:
|
||||
if getattr(model, "end2end", False):
|
||||
LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
||||
self.args.nms = False
|
||||
self.args.conf = self.args.conf or 0.25 # set conf default value for nms export
|
||||
if edgetpu:
|
||||
if not LINUX:
|
||||
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
|
||||
|
|
@ -344,8 +349,8 @@ class Exporter:
|
|||
)
|
||||
|
||||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
for _ in range(2): # dry runs
|
||||
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
|
||||
if self.args.half and onnx and self.device.type != "cpu":
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
|
||||
|
|
@ -476,7 +481,7 @@ class Exporter:
|
|||
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
|
||||
f = self.file.with_suffix(".torchscript")
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
|
||||
extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f"{prefix} optimizing for mobile...")
|
||||
|
|
@ -499,7 +504,6 @@ class Exporter:
|
|||
opset_version = self.args.opset or get_latest_opset()
|
||||
LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
|
||||
f = str(self.file.with_suffix(".onnx"))
|
||||
|
||||
output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
|
||||
dynamic = self.args.dynamic
|
||||
if dynamic:
|
||||
|
|
@ -509,9 +513,18 @@ class Exporter:
|
|||
dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
|
||||
elif isinstance(self.model, DetectionModel):
|
||||
dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
|
||||
if self.args.nms: # only batch size is dynamic with NMS
|
||||
dynamic["output0"].pop(2)
|
||||
if self.args.nms and self.model.task == "obb":
|
||||
self.args.opset = opset_version # for NMSModel
|
||||
# OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
|
||||
torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
|
||||
check_requirements("onnxslim>=0.1.46") # Older versions has bug with OBB
|
||||
|
||||
torch.onnx.export(
|
||||
self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
|
||||
NMSModel(self.model.cpu() if dynamic else self.model, self.args)
|
||||
if self.args.nms
|
||||
else self.model, # dynamic=True only compatible with cpu
|
||||
self.im.cpu() if dynamic else self.im,
|
||||
f,
|
||||
verbose=False,
|
||||
|
|
@ -553,7 +566,7 @@ class Exporter:
|
|||
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
|
||||
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
|
||||
ov_model = ov.convert_model(
|
||||
self.model,
|
||||
NMSModel(self.model, self.args) if self.args.nms else self.model,
|
||||
input=None if self.args.dynamic else [self.im.shape],
|
||||
example_input=self.im,
|
||||
)
|
||||
|
|
@ -736,9 +749,6 @@ class Exporter:
|
|||
f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
|
||||
if f.is_dir():
|
||||
shutil.rmtree(f)
|
||||
if self.args.nms and getattr(self.model, "end2end", False):
|
||||
LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
||||
self.args.nms = False
|
||||
|
||||
bias = [0.0, 0.0, 0.0]
|
||||
scale = 1 / 255
|
||||
|
|
@ -1438,8 +1448,8 @@ class Exporter:
|
|||
nms.coordinatesOutputFeatureName = "coordinates"
|
||||
nms.iouThresholdInputFeatureName = "iouThreshold"
|
||||
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
|
||||
nms.iouThreshold = 0.45
|
||||
nms.confidenceThreshold = 0.25
|
||||
nms.iouThreshold = self.args.iou
|
||||
nms.confidenceThreshold = self.args.conf
|
||||
nms.pickTop.perClass = True
|
||||
nms.stringClassLabels.vector.extend(names.values())
|
||||
nms_model = ct.models.MLModel(nms_spec)
|
||||
|
|
@ -1507,3 +1517,91 @@ class IOSDetectModel(torch.nn.Module):
|
|||
"""Normalize predictions of object detection model with input size-dependent factors."""
|
||||
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
||||
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||
|
||||
|
||||
class NMSModel(torch.nn.Module):
|
||||
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
|
||||
|
||||
def __init__(self, model, args):
|
||||
"""
|
||||
Initialize the NMSModel.
|
||||
|
||||
Args:
|
||||
model (torch.nn.module): The model to wrap with NMS postprocessing.
|
||||
args (Namespace): The export arguments.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.obb = model.task == "obb"
|
||||
self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
||||
|
||||
Args:
|
||||
x (torch.tensor): The preprocessed tensor with shape (N, 3, H, W).
|
||||
|
||||
Returns:
|
||||
out (torch.tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from torchvision.ops import nms
|
||||
|
||||
preds = self.model(x)
|
||||
pred = preds[0] if isinstance(preds, tuple) else preds
|
||||
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||
extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose
|
||||
boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
|
||||
scores, classes = scores.max(dim=-1)
|
||||
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
||||
out = torch.zeros(
|
||||
boxes.shape[0],
|
||||
self.args.max_det,
|
||||
boxes.shape[-1] + 2 + extra_shape,
|
||||
device=boxes.device,
|
||||
dtype=boxes.dtype,
|
||||
)
|
||||
for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
|
||||
mask = score > self.args.conf
|
||||
if self.is_tf:
|
||||
# TFLite GatherND error if mask is empty
|
||||
score *= mask
|
||||
# Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
|
||||
mask = score.topk(self.args.max_det * 5).indices
|
||||
box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
|
||||
if not self.obb:
|
||||
box = xywh2xyxy(box)
|
||||
if self.is_tf:
|
||||
# TFlite bug returns less boxes
|
||||
box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
|
||||
nmsbox = box.clone()
|
||||
# `8` is the minimum value experimented to get correct NMS results for obb
|
||||
multiplier = 8 if self.obb else 1
|
||||
# Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
|
||||
if self.args.format == "tflite": # TFLite is already normalized
|
||||
nmsbox *= multiplier
|
||||
else:
|
||||
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
|
||||
if not self.args.agnostic_nms: # class-specific NMS
|
||||
end = 2 if self.obb else 4
|
||||
# fully explicit expansion otherwise reshape error
|
||||
# large max_wh causes issues when quantizing
|
||||
cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
|
||||
offbox = nmsbox[:, :end] + cls_offset * multiplier
|
||||
nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
|
||||
nms_fn = (
|
||||
partial(nms_rotated, use_triu=not (self.is_tf or (self.args.opset or 14) < 14)) if self.obb else nms
|
||||
)
|
||||
keep = nms_fn(
|
||||
torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
|
||||
score,
|
||||
self.args.iou,
|
||||
)[: self.args.max_det]
|
||||
dets = torch.cat([box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1), extra[keep]], dim=-1)
|
||||
# Zero-pad to max_det size to avoid reshape error
|
||||
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
|
||||
out[i] = torch.nn.functional.pad(dets, pad)
|
||||
return (out, preds[1]) if self.model.task == "segment" else out
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ class Results(SimpleClass):
|
|||
if v is not None:
|
||||
return len(v)
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None):
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None):
|
||||
"""
|
||||
Updates the Results object with new detection data.
|
||||
|
||||
|
|
@ -318,6 +318,7 @@ class Results(SimpleClass):
|
|||
masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
|
||||
probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
|
||||
obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
|
||||
keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -332,6 +333,8 @@ class Results(SimpleClass):
|
|||
self.probs = probs
|
||||
if obb is not None:
|
||||
self.obb = OBB(obb, self.orig_shape)
|
||||
if keypoints is not None:
|
||||
self.keypoints = Keypoints(keypoints, self.orig_shape)
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -38,13 +38,7 @@ class NASValidator(DetectionValidator):
|
|||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
boxes = ops.xyxy2xywh(preds_in[0][0])
|
||||
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
|
||||
return ops.non_max_suppression(
|
||||
return super().postprocess(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=False,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
max_time_img=0.5,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,22 +20,54 @@ class DetectionPredictor(BasePredictor):
|
|||
```
|
||||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
def postprocess(self, preds, img, orig_imgs, **kwargs):
|
||||
"""Post-processes predictions and returns a list of Results objects."""
|
||||
preds = ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
self.args.classes,
|
||||
self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes,
|
||||
nc=len(self.model.names),
|
||||
end2end=getattr(self.model, "end2end", False),
|
||||
rotated=self.args.task == "obb",
|
||||
)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
results = []
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
|
||||
return self.construct_results(preds, img, orig_imgs, **kwargs)
|
||||
|
||||
def construct_results(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Constructs a list of result objects from the predictions.
|
||||
|
||||
Args:
|
||||
preds (List[torch.Tensor]): List of predicted bounding boxes and scores.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_imgs (List[np.ndarray]): List of original images before preprocessing.
|
||||
|
||||
Returns:
|
||||
(list): List of result objects containing the original images, image paths, class names, and bounding boxes.
|
||||
"""
|
||||
return [
|
||||
self.construct_result(pred, img, orig_img, img_path)
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
||||
]
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Constructs the result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes and scores.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, and bounding boxes.
|
||||
"""
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
|
||||
return results
|
||||
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ class DetectionValidator(BaseValidator):
|
|||
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.end2end = getattr(model, "end2end", False)
|
||||
self.metrics.names = self.names
|
||||
self.metrics.plot = self.args.plots
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
|
||||
|
|
@ -96,9 +97,12 @@ class DetectionValidator(BaseValidator):
|
|||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
nc=self.nc,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
end2end=self.end2end,
|
||||
rotated=self.args.task == "obb",
|
||||
)
|
||||
|
||||
def _prepare_batch(self, si, batch):
|
||||
|
|
|
|||
|
|
@ -27,27 +27,20 @@ class OBBPredictor(DetectionPredictor):
|
|||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "obb"
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-processes predictions and returns a list of Results objects."""
|
||||
preds = ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes,
|
||||
rotated=True,
|
||||
)
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Constructs the result object from the prediction.
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
||||
results = []
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, and oriented bounding boxes.
|
||||
"""
|
||||
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
|
||||
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
|
||||
# xywh, r, conf, cls
|
||||
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb))
|
||||
return results
|
||||
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
|
||||
|
|
|
|||
|
|
@ -36,20 +36,6 @@ class OBBValidator(DetectionValidator):
|
|||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply Non-maximum suppression to prediction outputs."""
|
||||
return ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
nc=self.nc,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
rotated=True,
|
||||
)
|
||||
|
||||
def _process_batch(self, detections, gt_bboxes, gt_cls):
|
||||
"""
|
||||
Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
|
||||
|
||||
|
|
@ -30,27 +29,21 @@ class PosePredictor(DetectionPredictor):
|
|||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Return detection results for a given input image or list of images."""
|
||||
preds = ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
classes=self.args.classes,
|
||||
nc=len(self.model.names),
|
||||
)
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Constructs the result object from the prediction.
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
||||
results = []
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round()
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, bounding boxes, and keypoints.
|
||||
"""
|
||||
result = super().construct_result(pred, img, orig_img, img_path)
|
||||
pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:]
|
||||
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
|
||||
results.append(
|
||||
Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts)
|
||||
)
|
||||
return results
|
||||
result.update(keypoints=pred_kpts)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -61,19 +61,6 @@ class PoseValidator(DetectionValidator):
|
|||
"mAP50-95)",
|
||||
)
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply non-maximum suppression and return detections with high confidence scores."""
|
||||
return ops.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=self.nc,
|
||||
)
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initiate pose estimation metrics for YOLO model."""
|
||||
super().init_metrics(model)
|
||||
|
|
|
|||
|
|
@ -27,29 +27,48 @@ class SegmentationPredictor(DetectionPredictor):
|
|||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Applies non-max suppression and processes detections for each image in an input batch."""
|
||||
p = ops.non_max_suppression(
|
||||
preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=len(self.model.names),
|
||||
classes=self.args.classes,
|
||||
)
|
||||
# tuple if PyTorch model or array if exported
|
||||
protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
|
||||
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
def construct_results(self, preds, img, orig_imgs, protos):
|
||||
"""
|
||||
Constructs a list of result objects from the predictions.
|
||||
|
||||
results = []
|
||||
proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported
|
||||
for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
|
||||
Args:
|
||||
preds (List[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_imgs (List[np.ndarray]): List of original images before preprocessing.
|
||||
protos (List[torch.Tensor]): List of prototype masks.
|
||||
|
||||
Returns:
|
||||
(list): List of result objects containing the original images, image paths, class names, bounding boxes, and masks.
|
||||
"""
|
||||
return [
|
||||
self.construct_result(pred, img, orig_img, img_path, proto)
|
||||
for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
|
||||
]
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path, proto):
|
||||
"""
|
||||
Constructs the result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (np.ndarray): The predicted bounding boxes, scores, and masks.
|
||||
img (torch.Tensor): The image after preprocessing.
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
proto (torch.Tensor): The prototype masks.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, bounding boxes, and masks.
|
||||
"""
|
||||
if not len(pred): # save empty boxes
|
||||
masks = None
|
||||
elif self.args.retina_masks:
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
|
||||
else:
|
||||
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
|
||||
return results
|
||||
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
|
||||
|
|
|
|||
|
|
@ -70,16 +70,7 @@ class SegmentationValidator(DetectionValidator):
|
|||
|
||||
def postprocess(self, preds):
|
||||
"""Post-processes YOLO predictions and returns output detections with proto."""
|
||||
p = ops.non_max_suppression(
|
||||
preds[0],
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
labels=self.lb,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=self.nc,
|
||||
)
|
||||
p = super().postprocess(preds[0])
|
||||
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
|
||||
return p, proto
|
||||
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ class AutoBackend(nn.Module):
|
|||
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
|
||||
nhwc = coreml or saved_model or pb or tflite or edgetpu or rknn # BHWC formats (vs torch BCWH)
|
||||
stride = 32 # default stride
|
||||
end2end = False # default end2end
|
||||
model, metadata, task = None, None, None
|
||||
|
||||
# Set device
|
||||
|
|
@ -222,16 +223,18 @@ class AutoBackend(nn.Module):
|
|||
output_names = [x.name for x in session.get_outputs()]
|
||||
metadata = session.get_modelmeta().custom_metadata_map
|
||||
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
|
||||
fp16 = True if "float16" in session.get_inputs()[0].type else False
|
||||
if not dynamic:
|
||||
io = session.io_binding()
|
||||
bindings = []
|
||||
for output in session.get_outputs():
|
||||
y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
|
||||
out_fp16 = "float16" in output.type
|
||||
y_tensor = torch.empty(output.shape, dtype=torch.float16 if out_fp16 else torch.float32).to(device)
|
||||
io.bind_output(
|
||||
name=output.name,
|
||||
device_type=device.type,
|
||||
device_id=device.index if cuda else 0,
|
||||
element_type=np.float16 if fp16 else np.float32,
|
||||
element_type=np.float16 if out_fp16 else np.float32,
|
||||
shape=tuple(y_tensor.shape),
|
||||
buffer_ptr=y_tensor.data_ptr(),
|
||||
)
|
||||
|
|
@ -501,7 +504,7 @@ class AutoBackend(nn.Module):
|
|||
for k, v in metadata.items():
|
||||
if k in {"stride", "batch"}:
|
||||
metadata[k] = int(v)
|
||||
elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
|
||||
elif k in {"imgsz", "names", "kpt_shape", "args"} and isinstance(v, str):
|
||||
metadata[k] = eval(v)
|
||||
stride = metadata["stride"]
|
||||
task = metadata["task"]
|
||||
|
|
@ -509,6 +512,7 @@ class AutoBackend(nn.Module):
|
|||
imgsz = metadata["imgsz"]
|
||||
names = metadata["names"]
|
||||
kpt_shape = metadata.get("kpt_shape")
|
||||
end2end = metadata.get("args", {}).get("nms", False)
|
||||
elif not (pt or triton or nn_module):
|
||||
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
|
||||
|
||||
|
|
@ -703,9 +707,12 @@ class AutoBackend(nn.Module):
|
|||
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
|
||||
# Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695
|
||||
# xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models
|
||||
if x.shape[-1] == 6: # end-to-end model
|
||||
if x.shape[-1] == 6 or self.end2end: # end-to-end model
|
||||
x[:, :, [0, 2]] *= w
|
||||
x[:, :, [1, 3]] *= h
|
||||
if self.task == "pose":
|
||||
x[:, :, 6::3] *= w
|
||||
x[:, :, 7::3] *= h
|
||||
else:
|
||||
x[:, [0, 2]] *= w
|
||||
x[:, [1, 3]] *= h
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ def make_divisible(x, divisor):
|
|||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def nms_rotated(boxes, scores, threshold=0.45):
|
||||
def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
|
||||
"""
|
||||
NMS for oriented bounding boxes using probiou and fast-nms.
|
||||
|
||||
|
|
@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
|
|||
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
|
||||
scores (torch.Tensor): Confidence scores, shape (N,).
|
||||
threshold (float, optional): IoU threshold. Defaults to 0.45.
|
||||
use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
|
||||
when exporting obb models to some formats that do not support `torch.triu`.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Indices of boxes to keep after NMS.
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return np.empty((0,), dtype=np.int8)
|
||||
sorted_idx = torch.argsort(scores, descending=True)
|
||||
boxes = boxes[sorted_idx]
|
||||
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
|
||||
pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
|
||||
ious = batch_probiou(boxes, boxes)
|
||||
if use_triu:
|
||||
ious = ious.triu_(diagonal=1)
|
||||
# pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
|
||||
# NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
|
||||
pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
|
||||
else:
|
||||
n = boxes.shape[0]
|
||||
row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
|
||||
col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
|
||||
upper_mask = row_idx < col_idx
|
||||
ious = ious * upper_mask
|
||||
# Zeroing these scores ensures the additional indices would not affect the final results
|
||||
scores[~((ious >= threshold).sum(0) <= 0)] = 0
|
||||
# NOTE: return indices with fixed length to avoid TFLite reshape error
|
||||
pick = torch.topk(scores, scores.shape[0]).indices
|
||||
return sorted_idx[pick]
|
||||
|
||||
|
||||
|
|
@ -179,6 +193,7 @@ def non_max_suppression(
|
|||
max_wh=7680,
|
||||
in_place=True,
|
||||
rotated=False,
|
||||
end2end=False,
|
||||
):
|
||||
"""
|
||||
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
||||
|
|
@ -205,6 +220,7 @@ def non_max_suppression(
|
|||
max_wh (int): The maximum box width and height in pixels.
|
||||
in_place (bool): If True, the input prediction tensor will be modified in place.
|
||||
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
|
||||
end2end (bool): If the model doesn't require NMS.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
|
||||
|
|
@ -221,7 +237,7 @@ def non_max_suppression(
|
|||
if classes is not None:
|
||||
classes = torch.tensor(classes, device=prediction.device)
|
||||
|
||||
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
|
||||
if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
|
||||
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
|
||||
if classes is not None:
|
||||
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue