YOLO Segment sigmoid() fix (#13939)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: DeepDiver <zhaoxu1015@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-06-25 11:53:42 +02:00 committed by GitHub
parent 3bb0c5afa3
commit b10e0f3fa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 7 deletions

View file

@ -260,7 +260,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- name: Install requirements
run: pip install -e .
run: pip install -e . pytest-cov
- name: Check environment
run: |
yolo checks

View file

@ -661,10 +661,10 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
(torch.Tensor): The upsampled masks.
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
return masks.gt_(0.0)
def process_mask(protos, masks_in, bboxes, shape, upsample=False):
@ -685,7 +685,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw) # CHW
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
width_ratio = mw / iw
height_ratio = mh / ih
@ -698,7 +698,7 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.5)
return masks.gt_(0.0)
def process_mask_native(protos, masks_in, bboxes, shape):
@ -715,10 +715,10 @@ def process_mask_native(protos, masks_in, bboxes, shape):
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
masks = scale_masks(masks[None], shape)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.5)
return masks.gt_(0.0)
def scale_masks(masks, shape, padding=True):