From d6edf1c1d8b67b7ff3829609f29bcacb4f107c29 Mon Sep 17 00:00:00 2001 From: AlainSchoebi <44315825+AlainSchoebi@users.noreply.github.com> Date: Fri, 23 Feb 2024 14:28:25 +0100 Subject: [PATCH] Add Non-Maximum Suppression (NMS) `inplace` flag (#8368) --- ultralytics/utils/ops.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index d1bfc6a8..4a835160 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -173,6 +173,7 @@ def non_max_suppression( max_time_img=0.05, max_nms=30000, max_wh=7680, + in_place=True, rotated=False, ): """ @@ -197,7 +198,8 @@ def non_max_suppression( nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks. max_time_img (float): The maximum time (seconds) for processing one image. max_nms (int): The maximum number of boxes into torchvision.ops.nms(). - max_wh (int): The maximum box width and height in pixels + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. Returns: (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of @@ -224,7 +226,10 @@ def non_max_suppression( prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) if not rotated: - prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy t = time.time() output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs