ultralytics 8.3.13 SAM prompt-inference refactor (#16894)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 5a9ce863e4
commit 15e6133534
4 changed files with 85 additions and 54 deletions
@@ -142,11 +142,20 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
 # Display model information (optional)
 model.info()

-# Segment with bounding box prompt
+# Run inference with bboxes prompt
 results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])

-# Segment with point prompt
-results = model("path/to/image.jpg", points=[150, 150], labels=[1])
+# Run inference with single point
+results = model(points=[900, 370], labels=[1])
+
+# Run inference with multiple points
+results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
+
+# Run inference with multiple points prompt per object
+results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
+
+# Run inference with negative points prompt
+results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
 ```

 #### Segment Everything
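Note: the new docs examples encode a prompt-nesting convention: a flat `[x, y]` is one point, a list of `[x, y]` pairs with a flat `labels` list is one point per object, and an extra nesting level groups several points (with per-point labels, 0 = background) onto a single object. A minimal sketch of the convention, assuming the SAM 2 weights name used elsewhere in these docs:

```python
from ultralytics import SAM

model = SAM("sam2_b.pt")  # weights name as used elsewhere in the SAM 2 docs

# One point for one object: points is a flat [x, y]
results = model("path/to/image.jpg", points=[900, 370], labels=[1])

# Two objects, one point each: points has shape (N, 2), labels shape (N,)
results = model("path/to/image.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])

# One object described by two points: shape (N, num_points, 2), labels (N, num_points);
# a 0 label marks a negative (background) point the predicted mask should avoid
results = model("path/to/image.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
```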
@@ -59,16 +59,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
 results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])

 # Run inference with single point
-results = predictor(points=[900, 370], labels=[1])
+results = model(points=[900, 370], labels=[1])

 # Run inference with multiple points
-results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
+results = model(points=[[400, 370], [900, 370]], labels=[1, 1])

 # Run inference with multiple points prompt per object
-results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
+results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])

 # Run inference with negative points prompt
-results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
+results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
 ```

 !!! example "Segment everything"
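Note: this hunk fixes leftover `predictor(...)` calls in the SAM docs to use the `model` variable the snippet actually defines. For repeated prompting of the same image without re-encoding features, the lower-level predictor API is the alternative; a sketch assuming the `SAMPredictor` overrides pattern shown elsewhere in these docs (override values are illustrative):

```python
from ultralytics.models.sam import Predictor as SAMPredictor

# Create a predictor with illustrative override values
overrides = dict(conf=0.25, task="segment", mode="predict", model="sam_b.pt")
predictor = SAMPredictor(overrides=overrides)

# Encode image features once, then prompt repeatedly
predictor.set_image("ultralytics/assets/zidane.jpg")
results = predictor(points=[900, 370], labels=[1])
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
predictor.reset_image()  # clear the cached image features when done
```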
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.12"
+__version__ = "8.3.13"

 import os

@@ -235,7 +235,42 @@ class Predictor(BasePredictor):
         """
         features = self.get_im_features(im) if self.features is None else self.features

-        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+        points = (points, labels) if points is not None else None
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
+
+        # Predict masks
+        pred_masks, pred_scores = self.model.mask_decoder(
+            image_embeddings=features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depending on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points doesn't match the number of labels when labels are passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+        """
+        src_shape = self.batch[1][0].shape[:2]
         r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
         # Transform input prompts
         if points is not None:
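Note: the decoder call moved into this method returns one mask per prompt per candidate, and the trailing `flatten(0, 1)` merges the candidate dimension into the batch. A standalone shape sketch (tensor sizes are illustrative):

```python
import torch

N, d, H, W = 2, 3, 256, 256  # d == 3 when multimask_output=True, else 1
pred_masks = torch.randn(N, d, H, W)  # stand-in for mask_decoder output
pred_scores = torch.rand(N, d)

flat_masks = pred_masks.flatten(0, 1)    # (N*d, H, W): each candidate becomes its own mask
flat_scores = pred_scores.flatten(0, 1)  # (N*d,): one quality score per candidate
assert flat_masks.shape == (N * d, H, W) and flat_scores.shape == (N * d,)
```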
@@ -258,23 +293,7 @@ class Predictor(BasePredictor):
             bboxes *= r
         if masks is not None:
             masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
-
-        points = (points, labels) if points is not None else None
-        # Embed prompts
-        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
-
-        # Predict masks
-        pred_masks, pred_scores = self.model.mask_decoder(
-            image_embeddings=features,
-            image_pe=self.model.prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=multimask_output,
-        )
-
-        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
-        # `d` could be 1 or 3 depends on `multimask_output`.
-        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+        return bboxes, points, labels, masks

     def generate(
         self,
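Note: `_prepare_prompts` now owns the coordinate transform these removed lines used to do inline: prompts arrive in original-image pixels and are scaled by the same `min` ratio the resize uses, so they stay aligned with the model-input image. A standalone sketch of the arithmetic with illustrative shapes:

```python
src_shape = (720, 1280)   # original image (height, width); comes from self.batch in the real code
dst_shape = (1024, 1024)  # model input (height, width)

# Same ratio used to resize the image, so prompts track the resized pixels
r = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])

point = (900, 370)                     # (x, y) in original-image pixels
scaled = (point[0] * r, point[1] * r)  # (x, y) in model-input pixels
print(r, scaled)                       # 0.8 (720.0, 296.0)
```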
@@ -693,34 +712,7 @@ class SAM2Predictor(Predictor):
         """
         features = self.get_im_features(im) if self.features is None else self.features

-        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
-        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
-        # Transform input prompts
-        if points is not None:
-            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
-            points = points[None] if points.ndim == 1 else points
-            # Assuming labels are all positive if users don't pass labels.
-            if labels is None:
-                labels = torch.ones(points.shape[0])
-            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
-            points *= r
-            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
-            points, labels = points[:, None], labels[:, None]
-        if bboxes is not None:
-            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
-            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
-            bboxes = bboxes.view(-1, 2, 2) * r
-            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
-            # NOTE: merge "boxes" and "points" into a single "points" input
-            # (where boxes are added at the beginning) to model.sam_prompt_encoder
-            if points is not None:
-                points = torch.cat([bboxes, points], dim=1)
-                labels = torch.cat([bbox_labels, labels], dim=1)
-            else:
-                points, labels = bboxes, bbox_labels
-        if masks is not None:
-            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
-
+        bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
         points = (points, labels) if points is not None else None

         sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
@@ -744,6 +736,36 @@ class SAM2Predictor(Predictor):
         # `d` could be 1 or 3 depends on `multimask_output`.
         return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)

+    def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Prepares and transforms the input prompts for processing based on the destination shape.
+
+        Args:
+            dst_shape (tuple): The target shape (height, width) for the prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
+            masks (List | np.ndarray | None): Masks for the objects, where each mask is a 2D array.
+
+        Raises:
+            AssertionError: If the number of points doesn't match the number of labels when labels are passed.
+
+        Returns:
+            (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+        """
+        bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
+        if bboxes is not None:
+            bboxes = bboxes.view(-1, 2, 2)
+            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+            # NOTE: merge "boxes" and "points" into a single "points" input
+            # (where boxes are added at the beginning) to model.sam_prompt_encoder
+            if points is not None:
+                points = torch.cat([bboxes, points], dim=1)
+                labels = torch.cat([bbox_labels, labels], dim=1)
+            else:
+                points, labels = bboxes, bbox_labels
+        return bboxes, points, labels, masks
+
     def set_image(self, image):
         """
         Preprocesses and sets a single image for inference using the SAM2 model.
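Note: SAM 2's prompt encoder takes boxes as a pair of corner points rather than a separate box input, which is why this override reshapes each XYXY box to two points and prepends them with the reserved labels 2 (top-left) and 3 (bottom-right). A standalone sketch of that merge with illustrative values:

```python
import torch

bboxes = torch.tensor([[100.0, 100.0, 200.0, 200.0]])  # (N, 4) XYXY, already rescaled
points = torch.tensor([[[150.0, 150.0]]])              # (N, num_points, 2)
labels = torch.tensor([[1]], dtype=torch.int32)        # (N, num_points)

corners = bboxes.view(-1, 2, 2)  # (N, 2, 2): top-left and bottom-right corners
# Reserved point labels: 2 = box top-left corner, 3 = box bottom-right corner
corner_labels = torch.tensor([[2, 3]], dtype=torch.int32).expand(len(corners), -1)

merged_points = torch.cat([corners, points], dim=1)        # (N, 2 + num_points, 2)
merged_labels = torch.cat([corner_labels, labels], dim=1)  # (N, 2 + num_points)
print(merged_points.shape, merged_labels.shape)            # [1, 3, 2] and [1, 3]
```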