Return boxes for SAM prompts inference (#16276)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Laughing 2024-09-14 18:11:46 +08:00 committed by GitHub
parent c2068df9d9
commit 02e995383d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 6 deletions

View file

@@ -101,7 +101,7 @@ def test_mobilesam():
model.predict(source, points=[900, 370], labels=[1]) model.predict(source, points=[900, 370], labels=[1])
# Predict a segment based on a box prompt # Predict a segment based on a box prompt
model.predict(source, bboxes=[439, 437, 524, 709]) model.predict(source, bboxes=[439, 437, 524, 709], save=True)
# Predict all # Predict all
# model(source) # model(source)

View file

@@ -450,16 +450,18 @@ class Predictor(BasePredictor):
results = [] results = []
for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
if len(masks) == 0: if len(masks) == 0:
masks = None masks = None
else: else:
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
masks = masks > self.model.mask_threshold # to bool masks = masks > self.model.mask_threshold # to bool
if pred_bboxes is not None:
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
else:
pred_bboxes = batched_mask_to_box(masks)
# NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
# Reset segment-all mode. # Reset segment-all mode.
self.segment_all = False self.segment_all = False