ultralytics 8.3.67 NMS Export for Detect, Segment, Pose and OBB YOLO models (#18484)
Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
0e48a00303
commit
9181ff62f5
17 changed files with 320 additions and 208 deletions
|
|
@ -103,7 +103,7 @@ from ultralytics.utils.checks import (
|
|||
)
|
||||
from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
|
||||
from ultralytics.utils.files import file_size, spaces_in_path
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.ops import Profile, nms_rotated, xywh2xyxy
|
||||
from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
|
||||
|
||||
|
||||
|
|
@ -111,16 +111,16 @@ def export_formats():
|
|||
"""Ultralytics YOLO export formats."""
|
||||
x = [
|
||||
["PyTorch", "-", ".pt", True, True, []],
|
||||
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]],
|
||||
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]],
|
||||
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]],
|
||||
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]],
|
||||
["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize", "nms"]],
|
||||
["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify", "nms"]],
|
||||
["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8", "nms"]],
|
||||
["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify", "nms"]],
|
||||
["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]],
|
||||
["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras", "nms"]],
|
||||
["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]],
|
||||
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]],
|
||||
["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []],
|
||||
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]],
|
||||
["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8", "nms"]],
|
||||
["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]],
|
||||
["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]],
|
||||
["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]],
|
||||
|
|
@ -281,6 +281,11 @@ class Exporter:
|
|||
)
|
||||
if self.args.int8 and tflite:
|
||||
assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models."
|
||||
if self.args.nms:
|
||||
if getattr(model, "end2end", False):
|
||||
LOGGER.warning("WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
||||
self.args.nms = False
|
||||
self.args.conf = self.args.conf or 0.25 # set conf default value for nms export
|
||||
if edgetpu:
|
||||
if not LINUX:
|
||||
raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler")
|
||||
|
|
@ -344,8 +349,8 @@ class Exporter:
|
|||
)
|
||||
|
||||
y = None
|
||||
for _ in range(2):
|
||||
y = model(im) # dry runs
|
||||
for _ in range(2): # dry runs
|
||||
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
|
||||
if self.args.half and onnx and self.device.type != "cpu":
|
||||
im, model = im.half(), model.half() # to FP16
|
||||
|
||||
|
|
@ -476,7 +481,7 @@ class Exporter:
|
|||
LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...")
|
||||
f = self.file.with_suffix(".torchscript")
|
||||
|
||||
ts = torch.jit.trace(self.model, self.im, strict=False)
|
||||
ts = torch.jit.trace(NMSModel(self.model, self.args) if self.args.nms else self.model, self.im, strict=False)
|
||||
extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap()
|
||||
if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
|
||||
LOGGER.info(f"{prefix} optimizing for mobile...")
|
||||
|
|
@ -499,7 +504,6 @@ class Exporter:
|
|||
opset_version = self.args.opset or get_latest_opset()
|
||||
LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...")
|
||||
f = str(self.file.with_suffix(".onnx"))
|
||||
|
||||
output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"]
|
||||
dynamic = self.args.dynamic
|
||||
if dynamic:
|
||||
|
|
@ -509,9 +513,18 @@ class Exporter:
|
|||
dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160)
|
||||
elif isinstance(self.model, DetectionModel):
|
||||
dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400)
|
||||
if self.args.nms: # only batch size is dynamic with NMS
|
||||
dynamic["output0"].pop(2)
|
||||
if self.args.nms and self.model.task == "obb":
|
||||
self.args.opset = opset_version # for NMSModel
|
||||
# OBB error https://github.com/pytorch/pytorch/issues/110859#issuecomment-1757841865
|
||||
torch.onnx.register_custom_op_symbolic("aten::lift_fresh", lambda g, x: x, opset_version)
|
||||
check_requirements("onnxslim>=0.1.46") # Older versions has bug with OBB
|
||||
|
||||
torch.onnx.export(
|
||||
self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu
|
||||
NMSModel(self.model.cpu() if dynamic else self.model, self.args)
|
||||
if self.args.nms
|
||||
else self.model, # dynamic=True only compatible with cpu
|
||||
self.im.cpu() if dynamic else self.im,
|
||||
f,
|
||||
verbose=False,
|
||||
|
|
@ -553,7 +566,7 @@ class Exporter:
|
|||
LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...")
|
||||
assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed"
|
||||
ov_model = ov.convert_model(
|
||||
self.model,
|
||||
NMSModel(self.model, self.args) if self.args.nms else self.model,
|
||||
input=None if self.args.dynamic else [self.im.shape],
|
||||
example_input=self.im,
|
||||
)
|
||||
|
|
@ -736,9 +749,6 @@ class Exporter:
|
|||
f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage")
|
||||
if f.is_dir():
|
||||
shutil.rmtree(f)
|
||||
if self.args.nms and getattr(self.model, "end2end", False):
|
||||
LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.")
|
||||
self.args.nms = False
|
||||
|
||||
bias = [0.0, 0.0, 0.0]
|
||||
scale = 1 / 255
|
||||
|
|
@ -1438,8 +1448,8 @@ class Exporter:
|
|||
nms.coordinatesOutputFeatureName = "coordinates"
|
||||
nms.iouThresholdInputFeatureName = "iouThreshold"
|
||||
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
|
||||
nms.iouThreshold = 0.45
|
||||
nms.confidenceThreshold = 0.25
|
||||
nms.iouThreshold = self.args.iou
|
||||
nms.confidenceThreshold = self.args.conf
|
||||
nms.pickTop.perClass = True
|
||||
nms.stringClassLabels.vector.extend(names.values())
|
||||
nms_model = ct.models.MLModel(nms_spec)
|
||||
|
|
@ -1507,3 +1517,91 @@ class IOSDetectModel(torch.nn.Module):
|
|||
"""Normalize predictions of object detection model with input size-dependent factors."""
|
||||
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
|
||||
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
|
||||
|
||||
|
||||
class NMSModel(torch.nn.Module):
|
||||
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
|
||||
|
||||
def __init__(self, model, args):
|
||||
"""
|
||||
Initialize the NMSModel.
|
||||
|
||||
Args:
|
||||
model (torch.nn.module): The model to wrap with NMS postprocessing.
|
||||
args (Namespace): The export arguments.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.args = args
|
||||
self.obb = model.task == "obb"
|
||||
self.is_tf = self.args.format in frozenset({"saved_model", "tflite", "tfjs"})
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Performs inference with NMS post-processing. Supports Detect, Segment, OBB and Pose.
|
||||
|
||||
Args:
|
||||
x (torch.tensor): The preprocessed tensor with shape (N, 3, H, W).
|
||||
|
||||
Returns:
|
||||
out (torch.tensor): The post-processed results with shape (N, max_det, 4 + 2 + extra_shape).
|
||||
"""
|
||||
from functools import partial
|
||||
|
||||
from torchvision.ops import nms
|
||||
|
||||
preds = self.model(x)
|
||||
pred = preds[0] if isinstance(preds, tuple) else preds
|
||||
pred = pred.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
|
||||
extra_shape = pred.shape[-1] - (4 + self.model.nc) # extras from Segment, OBB, Pose
|
||||
boxes, scores, extras = pred.split([4, self.model.nc, extra_shape], dim=2)
|
||||
scores, classes = scores.max(dim=-1)
|
||||
# (N, max_det, 4 coords + 1 class score + 1 class label + extra_shape).
|
||||
out = torch.zeros(
|
||||
boxes.shape[0],
|
||||
self.args.max_det,
|
||||
boxes.shape[-1] + 2 + extra_shape,
|
||||
device=boxes.device,
|
||||
dtype=boxes.dtype,
|
||||
)
|
||||
for i, (box, cls, score, extra) in enumerate(zip(boxes, classes, scores, extras)):
|
||||
mask = score > self.args.conf
|
||||
if self.is_tf:
|
||||
# TFLite GatherND error if mask is empty
|
||||
score *= mask
|
||||
# Explicit length otherwise reshape error, hardcoded to `self.args.max_det * 5`
|
||||
mask = score.topk(self.args.max_det * 5).indices
|
||||
box, score, cls, extra = box[mask], score[mask], cls[mask], extra[mask]
|
||||
if not self.obb:
|
||||
box = xywh2xyxy(box)
|
||||
if self.is_tf:
|
||||
# TFlite bug returns less boxes
|
||||
box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
|
||||
nmsbox = box.clone()
|
||||
# `8` is the minimum value experimented to get correct NMS results for obb
|
||||
multiplier = 8 if self.obb else 1
|
||||
# Normalize boxes for NMS since large values for class offset causes issue with int8 quantization
|
||||
if self.args.format == "tflite": # TFLite is already normalized
|
||||
nmsbox *= multiplier
|
||||
else:
|
||||
nmsbox = multiplier * nmsbox / torch.tensor(x.shape[2:], device=box.device, dtype=box.dtype).max()
|
||||
if not self.args.agnostic_nms: # class-specific NMS
|
||||
end = 2 if self.obb else 4
|
||||
# fully explicit expansion otherwise reshape error
|
||||
# large max_wh causes issues when quantizing
|
||||
cls_offset = cls.reshape(-1, 1).expand(nmsbox.shape[0], end)
|
||||
offbox = nmsbox[:, :end] + cls_offset * multiplier
|
||||
nmsbox = torch.cat((offbox, nmsbox[:, end:]), dim=-1)
|
||||
nms_fn = (
|
||||
partial(nms_rotated, use_triu=not (self.is_tf or (self.args.opset or 14) < 14)) if self.obb else nms
|
||||
)
|
||||
keep = nms_fn(
|
||||
torch.cat([nmsbox, extra], dim=-1) if self.obb else nmsbox,
|
||||
score,
|
||||
self.args.iou,
|
||||
)[: self.args.max_det]
|
||||
dets = torch.cat([box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1), extra[keep]], dim=-1)
|
||||
# Zero-pad to max_det size to avoid reshape error
|
||||
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
|
||||
out[i] = torch.nn.functional.pad(dets, pad)
|
||||
return (out, preds[1]) if self.model.task == "segment" else out
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ class Results(SimpleClass):
|
|||
if v is not None:
|
||||
return len(v)
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None):
|
||||
def update(self, boxes=None, masks=None, probs=None, obb=None, keypoints=None):
|
||||
"""
|
||||
Updates the Results object with new detection data.
|
||||
|
||||
|
|
@ -318,6 +318,7 @@ class Results(SimpleClass):
|
|||
masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks.
|
||||
probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities.
|
||||
obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates.
|
||||
keypoints (torch.Tensor | None): A tensor of shape (N, 17, 3) containing keypoints.
|
||||
|
||||
Examples:
|
||||
>>> results = model("image.jpg")
|
||||
|
|
@ -332,6 +333,8 @@ class Results(SimpleClass):
|
|||
self.probs = probs
|
||||
if obb is not None:
|
||||
self.obb = OBB(obb, self.orig_shape)
|
||||
if keypoints is not None:
|
||||
self.keypoints = Keypoints(keypoints, self.orig_shape)
|
||||
|
||||
def _apply(self, fn, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue