ultralytics 8.3.13 SAM prompt-inference refactor (#16894)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
5a9ce863e4
commit
15e6133534
4 changed files with 85 additions and 54 deletions
|
|
@ -142,11 +142,20 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
|
||||||
# Display model information (optional)
|
# Display model information (optional)
|
||||||
model.info()
|
model.info()
|
||||||
|
|
||||||
# Segment with bounding box prompt
|
# Run inference with bboxes prompt
|
||||||
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
|
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
|
||||||
|
|
||||||
# Segment with point prompt
|
# Run inference with single point
|
||||||
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
|
results = model(points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Run inference with multiple points
|
||||||
|
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||||
|
|
||||||
|
# Run inference with multiple points prompt per object
|
||||||
|
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Run inference with negative points prompt
|
||||||
|
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Segment Everything
|
#### Segment Everything
|
||||||
|
|
|
||||||
|
|
@ -59,16 +59,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
|
||||||
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||||
|
|
||||||
# Run inference with single point
|
# Run inference with single point
|
||||||
results = predictor(points=[900, 370], labels=[1])
|
results = model(points=[900, 370], labels=[1])
|
||||||
|
|
||||||
# Run inference with multiple points
|
# Run inference with multiple points
|
||||||
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
|
results = model(points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||||
|
|
||||||
# Run inference with multiple points prompt per object
|
# Run inference with multiple points prompt per object
|
||||||
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
# Run inference with negative points prompt
|
# Run inference with negative points prompt
|
||||||
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
results = model(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! example "Segment everything"
|
!!! example "Segment everything"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.3.12"
|
__version__ = "8.3.13"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -235,7 +235,42 @@ class Predictor(BasePredictor):
|
||||||
"""
|
"""
|
||||||
features = self.get_im_features(im) if self.features is None else self.features
|
features = self.get_im_features(im) if self.features is None else self.features
|
||||||
|
|
||||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||||
|
points = (points, labels) if points is not None else None
|
||||||
|
# Embed prompts
|
||||||
|
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
|
||||||
|
|
||||||
|
# Predict masks
|
||||||
|
pred_masks, pred_scores = self.model.mask_decoder(
|
||||||
|
image_embeddings=features,
|
||||||
|
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
||||||
|
sparse_prompt_embeddings=sparse_embeddings,
|
||||||
|
dense_prompt_embeddings=dense_embeddings,
|
||||||
|
multimask_output=multimask_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
|
||||||
|
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||||
|
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||||
|
|
||||||
|
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
|
||||||
|
"""
|
||||||
|
Prepares and transforms the input prompts for processing based on the destination shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_shape (tuple): The target shape (height, width) for the prompts.
|
||||||
|
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||||
|
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
|
||||||
|
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
|
||||||
|
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
|
||||||
|
"""
|
||||||
|
src_shape = self.batch[1][0].shape[:2]
|
||||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
||||||
# Transform input prompts
|
# Transform input prompts
|
||||||
if points is not None:
|
if points is not None:
|
||||||
|
|
@ -258,23 +293,7 @@ class Predictor(BasePredictor):
|
||||||
bboxes *= r
|
bboxes *= r
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||||
|
return bboxes, points, labels, masks
|
||||||
points = (points, labels) if points is not None else None
|
|
||||||
# Embed prompts
|
|
||||||
sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks)
|
|
||||||
|
|
||||||
# Predict masks
|
|
||||||
pred_masks, pred_scores = self.model.mask_decoder(
|
|
||||||
image_embeddings=features,
|
|
||||||
image_pe=self.model.prompt_encoder.get_dense_pe(),
|
|
||||||
sparse_prompt_embeddings=sparse_embeddings,
|
|
||||||
dense_prompt_embeddings=dense_embeddings,
|
|
||||||
multimask_output=multimask_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
|
|
||||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
|
||||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
|
|
@ -693,34 +712,7 @@ class SAM2Predictor(Predictor):
|
||||||
"""
|
"""
|
||||||
features = self.get_im_features(im) if self.features is None else self.features
|
features = self.get_im_features(im) if self.features is None else self.features
|
||||||
|
|
||||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
|
||||||
# Transform input prompts
|
|
||||||
if points is not None:
|
|
||||||
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
|
|
||||||
points = points[None] if points.ndim == 1 else points
|
|
||||||
# Assuming labels are all positive if users don't pass labels.
|
|
||||||
if labels is None:
|
|
||||||
labels = torch.ones(points.shape[0])
|
|
||||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
|
||||||
points *= r
|
|
||||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
|
||||||
points, labels = points[:, None], labels[:, None]
|
|
||||||
if bboxes is not None:
|
|
||||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
|
||||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
|
||||||
bboxes = bboxes.view(-1, 2, 2) * r
|
|
||||||
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
|
|
||||||
# NOTE: merge "boxes" and "points" into a single "points" input
|
|
||||||
# (where boxes are added at the beginning) to model.sam_prompt_encoder
|
|
||||||
if points is not None:
|
|
||||||
points = torch.cat([bboxes, points], dim=1)
|
|
||||||
labels = torch.cat([bbox_labels, labels], dim=1)
|
|
||||||
else:
|
|
||||||
points, labels = bboxes, bbox_labels
|
|
||||||
if masks is not None:
|
|
||||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
|
||||||
|
|
||||||
points = (points, labels) if points is not None else None
|
points = (points, labels) if points is not None else None
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||||
|
|
@ -744,6 +736,36 @@ class SAM2Predictor(Predictor):
|
||||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||||
|
|
||||||
|
def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None):
|
||||||
|
"""
|
||||||
|
Prepares and transforms the input prompts for processing based on the destination shape.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dst_shape (tuple): The target shape (height, width) for the prompts.
|
||||||
|
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||||
|
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
|
||||||
|
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
|
||||||
|
masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
|
||||||
|
"""
|
||||||
|
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
|
||||||
|
if bboxes is not None:
|
||||||
|
bboxes = bboxes.view(-1, 2, 2)
|
||||||
|
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
|
||||||
|
# NOTE: merge "boxes" and "points" into a single "points" input
|
||||||
|
# (where boxes are added at the beginning) to model.sam_prompt_encoder
|
||||||
|
if points is not None:
|
||||||
|
points = torch.cat([bboxes, points], dim=1)
|
||||||
|
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||||
|
else:
|
||||||
|
points, labels = bboxes, bbox_labels
|
||||||
|
return bboxes, points, labels, masks
|
||||||
|
|
||||||
def set_image(self, image):
|
def set_image(self, image):
|
||||||
"""
|
"""
|
||||||
Preprocesses and sets a single image for inference using the SAM2 model.
|
Preprocesses and sets a single image for inference using the SAM2 model.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue