ultralytics 8.3.67 NMS Export for Detect, Segment, Pose and OBB YOLO models (#18484)

Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
Mohammed Yasin 2025-01-24 18:00:36 +08:00 committed by GitHub
parent 0e48a00303
commit 9181ff62f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 320 additions and 208 deletions

View file

@ -143,7 +143,7 @@ def make_divisible(x, divisor):
return math.ceil(x / divisor) * divisor
def nms_rotated(boxes, scores, threshold=0.45):
def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
"""
NMS for oriented bounding boxes using probiou and fast-nms.
@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
scores (torch.Tensor): Confidence scores, shape (N,).
threshold (float, optional): IoU threshold. Defaults to 0.45.
use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
when exporting obb models to some formats that do not support `torch.triu`.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
"""
if len(boxes) == 0:
return np.empty((0,), dtype=np.int8)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
ious = batch_probiou(boxes, boxes)
if use_triu:
ious = ious.triu_(diagonal=1)
# pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
# NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
else:
n = boxes.shape[0]
row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
upper_mask = row_idx < col_idx
ious = ious * upper_mask
# Zeroing these scores ensures the additional indices would not affect the final results
scores[~((ious >= threshold).sum(0) <= 0)] = 0
# NOTE: return indices with fixed length to avoid TFLite reshape error
pick = torch.topk(scores, scores.shape[0]).indices
return sorted_idx[pick]
@ -179,6 +193,7 @@ def non_max_suppression(
max_wh=7680,
in_place=True,
rotated=False,
end2end=False,
):
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
@ -205,6 +220,7 @@ def non_max_suppression(
max_wh (int): The maximum box width and height in pixels.
in_place (bool): If True, the input prediction tensor will be modified in place.
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
end2end (bool): If the model doesn't require NMS.
Returns:
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
@ -221,7 +237,7 @@ def non_max_suppression(
if classes is not None:
classes = torch.tensor(classes, device=prediction.device)
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
if classes is not None:
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]