ultralytics 8.3.67 NMS Export for Detect, Segment, Pose and OBB YOLO models (#18484)

Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
Mohammed Yasin 2025-01-24 18:00:36 +08:00 committed by GitHub
parent 0e48a00303
commit 9181ff62f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 320 additions and 208 deletions

View file

@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
)
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
from ultralytics.utils.files import file_size, spaces_in_path
from ultralytics.utils.ops import Profile
from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
@ -111,16 +111,16 @@ def export_formats():
"""Ultralytics YOLO export formats."""
x = [
["PyTorch", "-", ".pt", True, True, []],
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
@ -281,6 +281,11 @@ class Exporter:
)
if self.args.int8 and tflite:
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
if self.args.nms:
if getattr(model, "end2end", False):
LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
self.args.nms = False
self.args.conf = self.args.conf or 0.25 # set conf default value for nms export
if edgetpu:
if not LINUX:
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
@ -344,8 +349,8 @@ class Exporter:
)
y = None
for _ in range(2):
y = model(im) # dry runs
for _ in range(2): # dry runs
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
if self.args.half and onnx and self.device.type != "cpu":
im, model = im.half(), model.half() # to FP16
@ -476,7 +481,7 @@ class Exporter:
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
f = self.file.with_suffix(".torchscript")
ts = torch.jit.trace(self.model, self.im, strict=False)
ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
LOGGER.info(f"{prefix} optimizing for mobile...")
@ -499,7 +504,6 @@ class Exporter:
opset_version = self.args.opset or get_latest_opset()
LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
f = str(self.file.with_suffix(".onnx"))
output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
dynamic = self.args.dynamic
if dynamic:
@ -509,9 +513,18 @@ class Exporter:
dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
elif isinstance(self.model, DetectionModel):
dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
if self.args.nms: # only batch size is dynamic with NMS
dynamic["output0"].pop(2)
if self.args.nms and self.model.task == "obb":
self.args.opset = opset_version # for NMSModel
# OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
check_requirements("onnxslim>=0.1.46") # Older versions has bug with OBB
torch.onnx.export(
self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
NMSModel(self.model.cpu() if dynamic else self.model, self.args)
if self.args.nms
else self.model, # dynamic=True only compatible with cpu
self.im.cpu() if dynamic else self.im,
f,
verbose=False,
@ -553,7 +566,7 @@ class Exporter:
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
ov_model = ov.convert_model(
self.model,
NMSModel(self.model, self.args) if self.args.nms else self.model,
input=None if self.args.dynamic else [self.im.shape],
example_input=self.im,
)
@ -736,9 +749,6 @@ class Exporter:
f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
if f.is_dir():
shutil.rmtree(f)
if self.args.nms and getattr(self.model, "end2end", False):
LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
self.args.nms = False
bias = [0.0, 0.0, 0.0]
scale = 1 / 255
@ -1438,8 +1448,8 @@ class Exporter:
nms.coordinatesOutputFeatureName = "coordinates"
nms.iouThresholdInputFeatureName = "iouThreshold"
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
nms.iouThreshold = 0.45
nms.confidenceThreshold = 0.25
nms.iouThreshold = self.args.iou
nms.confidenceThreshold = self.args.conf
nms.pickTop.perClass = True
nms.stringClassLabels.vector.extend(names.values())
nms_model = ct.models.MLModel(nms_spec)
@ -1507,3 +1517,91 @@ class IOSDetectModel(torch.nn.Module):
"""Normalize predictions of object detection model with input size-dependent factors."""
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
class NMSModel(torch.nn.Module):
    """Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""

    def __init__(self, model, args):
        """
        Initialize the NMSModel.

        Args:
            model (torch.nn.Module): The model to wrap with NMS postprocessing. Its `task` and `nc` attributes
                are read by this wrapper.
            args (Namespace): The export arguments. This wrapper reads `format`, `conf`, `iou`, `max_det`,
                `agnostic_nms` and `opset`.
        """
        super().__init__()
        self.model = model
        self.args = args
        self.obb = model.task == "obb"
        self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})

    def forward(self, x):
        """
        Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.

        Args:
            x (torch.Tensor): The preprocessed tensor with shape (N, 3, H, W).

        Returns:
            out (torch.Tensor | tuple): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape),
                zero-padded along the detection axis. For segment models a tuple `(out, preds[1])` is returned,
                where `preds[1]` is the second output of the wrapped model passed through unchanged.
        """
        # Imported lazily so that these are only needed when an NMS export is actually requested
        from functools import partial

        from torchvision.ops import nms

        preds = self.model(x)
        pred = preds[0] if isinstance(preds, tuple) else preds
        pred = pred.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
        extra_shape = pred.shape[-1] - (4 + self.model.nc)  # extras from Segment, OBB, Pose
        boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
        scores, classes = scores.max(dim=-1)
        # (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
        out = torch.zeros(
            boxes.shape[0],
            self.args.max_det,
            boxes.shape[-1] + 2 + extra_shape,
            device=boxes.device,
            dtype=boxes.dtype,
        )
        for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
            mask = score > self.args.conf
            if self.is_tf:
                # TFLite GatherND error if mask is empty
                score *= mask
                # Explicit length otherwise reshape error, fixed to `self.args.max_det * 5` candidates.
                # Clamp to the number of available candidates: topk raises if k exceeds the dimension size,
                # which happens for small input sizes that produce fewer than `max_det * 5` predictions.
                mask = score.topk(min(self.args.max_det * 5, score.shape[0])).indices
            box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
            if not self.obb:
                box = xywh2xyxy(box)
            if self.is_tf:
                # TFLite bug returns fewer boxes, so pad back to the expected number of rows
                box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
            nmsbox = box.clone()
            # `8` is the minimum value experimented to get correct NMS results for obb
            multiplier = 8 if self.obb else 1
            # Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
            if self.args.format == "tflite":  # TFLite is already normalized
                nmsbox *= multiplier
            else:
                nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
            if not self.args.agnostic_nms:  # class-specific NMS
                end = 2 if self.obb else 4
                # fully explicit expansion otherwise reshape error
                # large max_wh causes issues when quantizing
                cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
                # Shift the first `end` box columns by the class index so different classes never overlap
                offbox = nmsbox[:, :end] + cls_offset * multiplier
                nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
            nms_fn = (
                partial(nms_rotated, use_triu=not (self.is_tf or (self.args.opset or 14) < 14)) if self.obb else nms
            )
            keep = nms_fn(
                torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
                score,
                self.args.iou,
            )[: self.args.max_det]
            dets = torch.cat([box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1), extra[keep]], dim=-1)
            # Zero-pad to max_det size to avoid reshape error
            pad = (0, 0, 0, self.args.max_det - dets.shape[0])
            out[i] = torch.nn.functional.pad(dets, pad)
        return (out, preds[1]) if self.model.task == "segment" else out