From 0727b43ca6f754e094f7206c4aefd96d208b1c2c Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Mon, 25 Nov 2024 23:39:14 +0800 Subject: [PATCH] Concat all segments by default for multi-part masks (#16826) Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/utils/ops.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index b76168f9..07b54b3e 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -783,23 +783,29 @@ def regularize_rboxes(rboxes): return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes -def masks2segments(masks, strategy="largest"): +def masks2segments(masks, strategy="all"): """ It takes a list of masks(n,h,w) and returns a list of segments(n,xy). Args: masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) - strategy (str): 'concat' or 'largest'. Defaults to largest + strategy (str): 'all' or 'largest'. Defaults to all Returns: segments (List): list of segment masks """ + from ultralytics.data.converter import merge_multi_segment + segments = [] for x in masks.int().cpu().numpy().astype("uint8"): c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] if c: - if strategy == "concat": # concatenate all segments - c = np.concatenate([x.reshape(-1, 2) for x in c]) + if strategy == "all": # merge and concatenate all segments + c = ( + np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) + if len(c) > 1 + else c[0].reshape(-1, 2) + ) elif strategy == "largest": # select largest segment c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) else: