Concat all segments by default for multi-part masks (#16826)

Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2024-11-25 23:39:14 +08:00 committed by GitHub
parent c5fd0bb48a
commit 0727b43ca6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -783,23 +783,29 @@ def regularize_rboxes(rboxes):
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
def masks2segments(masks, strategy="largest"):
def masks2segments(masks, strategy="all"):
"""
It takes a list of masks(n,h,w) and returns a list of segments(n,xy).
Args:
masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
strategy (str): 'concat' or 'largest'. Defaults to largest
strategy (str): 'all' or 'largest'. Defaults to all
Returns:
segments (List): list of segment masks
"""
from ultralytics.data.converter import merge_multi_segment
segments = []
for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
if strategy == "concat": # concatenate all segments
c = np.concatenate([x.reshape(-1, 2) for x in c])
if strategy == "all": # merge and concatenate all segments
c = (
np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
if len(c) > 1
else c[0].reshape(-1, 2)
)
elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else: