Return boxes for SAM prompts inference (#16276)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
c2068df9d9
commit
02e995383d
2 changed files with 8 additions and 6 deletions
|
|
@ -101,7 +101,7 @@ def test_mobilesam():
|
|||
model.predict(source, points=[900, 370], labels=[1])
|
||||
|
||||
# Predict a segment based on a box prompt
|
||||
model.predict(source, bboxes=[439, 437, 524, 709])
|
||||
model.predict(source, bboxes=[439, 437, 524, 709], save=True)
|
||||
|
||||
# Predict all
|
||||
# model(source)
|
||||
|
|
|
|||
|
|
@ -450,16 +450,18 @@ class Predictor(BasePredictor):
|
|||
|
||||
results = []
|
||||
for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]):
|
||||
if pred_bboxes is not None:
|
||||
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
|
||||
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
|
||||
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
|
||||
|
||||
if len(masks) == 0:
|
||||
masks = None
|
||||
else:
|
||||
masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0]
|
||||
masks = masks > self.model.mask_threshold # to bool
|
||||
if pred_bboxes is not None:
|
||||
pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False)
|
||||
else:
|
||||
pred_bboxes = batched_mask_to_box(masks)
|
||||
# NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency.
|
||||
cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device)
|
||||
pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1)
|
||||
results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes))
|
||||
# Reset segment-all mode.
|
||||
self.segment_all = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue