ultralytics 8.3.67 NMS Export for Detect, Segment, Pose and OBB YOLO models (#18484)
Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
0e48a00303
commit
9181ff62f5
17 changed files with 320 additions and 208 deletions
|
|
@ -143,7 +143,7 @@ def make_divisible(x, divisor):
|
|||
return math.ceil(x / divisor) * divisor
|
||||
|
||||
|
||||
def nms_rotated(boxes, scores, threshold=0.45):
|
||||
def nms_rotated(boxes, scores, threshold=0.45, use_triu=True):
|
||||
"""
|
||||
NMS for oriented bounding boxes using probiou and fast-nms.
|
||||
|
||||
|
|
@ -151,16 +151,30 @@ def nms_rotated(boxes, scores, threshold=0.45):
|
|||
boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr.
|
||||
scores (torch.Tensor): Confidence scores, shape (N,).
|
||||
threshold (float, optional): IoU threshold. Defaults to 0.45.
|
||||
use_triu (bool, optional): Whether to use `torch.triu` operator. It'd be useful for disable it
|
||||
when exporting obb models to some formats that do not support `torch.triu`.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Indices of boxes to keep after NMS.
|
||||
"""
|
||||
if len(boxes) == 0:
|
||||
return np.empty((0,), dtype=np.int8)
|
||||
sorted_idx = torch.argsort(scores, descending=True)
|
||||
boxes = boxes[sorted_idx]
|
||||
ious = batch_probiou(boxes, boxes).triu_(diagonal=1)
|
||||
pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
|
||||
ious = batch_probiou(boxes, boxes)
|
||||
if use_triu:
|
||||
ious = ious.triu_(diagonal=1)
|
||||
# pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1)
|
||||
# NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition
|
||||
pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1)
|
||||
else:
|
||||
n = boxes.shape[0]
|
||||
row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
|
||||
col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
|
||||
upper_mask = row_idx < col_idx
|
||||
ious = ious * upper_mask
|
||||
# Zeroing these scores ensures the additional indices would not affect the final results
|
||||
scores[~((ious >= threshold).sum(0) <= 0)] = 0
|
||||
# NOTE: return indices with fixed length to avoid TFLite reshape error
|
||||
pick = torch.topk(scores, scores.shape[0]).indices
|
||||
return sorted_idx[pick]
|
||||
|
||||
|
||||
|
|
@ -179,6 +193,7 @@ def non_max_suppression(
|
|||
max_wh=7680,
|
||||
in_place=True,
|
||||
rotated=False,
|
||||
end2end=False,
|
||||
):
|
||||
"""
|
||||
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
|
||||
|
|
@ -205,6 +220,7 @@ def non_max_suppression(
|
|||
max_wh (int): The maximum box width and height in pixels.
|
||||
in_place (bool): If True, the input prediction tensor will be modified in place.
|
||||
rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS.
|
||||
end2end (bool): If the model doesn't require NMS.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
|
||||
|
|
@ -221,7 +237,7 @@ def non_max_suppression(
|
|||
if classes is not None:
|
||||
classes = torch.tensor(classes, device=prediction.device)
|
||||
|
||||
if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6)
|
||||
if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
|
||||
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
|
||||
if classes is not None:
|
||||
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue