ultralytics 8.2.72 SAM 2 multiple-bboxes support (#14928)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
2187649f99
commit
bea4c93278
2 changed files with 11 additions and 16 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.71"
|
||||
__version__ = "8.2.72"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -102,28 +102,23 @@ class SAM2Predictor(Predictor):
|
|||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bboxes *= r
|
||||
bboxes = bboxes.view(-1, 2, 2) * r
|
||||
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
|
||||
# NOTE: merge "boxes" and "points" into a single "points" input
|
||||
# (where boxes are added at the beginning) to model.sam_prompt_encoder
|
||||
if points is not None:
|
||||
points = torch.cat([bboxes, points], dim=1)
|
||||
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||
else:
|
||||
points, labels = bboxes, bbox_labels
|
||||
if masks is not None:
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
# TODO: Embed prompts
|
||||
# if bboxes is not None:
|
||||
# box_coords = bboxes.reshape(-1, 2, 2)
|
||||
# box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device)
|
||||
# box_labels = box_labels.repeat(bboxes.size(0), 1)
|
||||
# # we merge "boxes" and "points" into a single "concat_points" input (where
|
||||
# # boxes are added at the beginning) to sam_prompt_encoder
|
||||
# if concat_points is not None:
|
||||
# concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
|
||||
# concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
|
||||
# concat_points = (concat_coords, concat_labels)
|
||||
# else:
|
||||
# concat_points = (box_coords, box_labels)
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||
points=points,
|
||||
boxes=bboxes,
|
||||
boxes=None,
|
||||
masks=masks,
|
||||
)
|
||||
# Predict masks
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue