ultralytics 8.3.12 SAM and SAM2 multi-point prompts (#16643)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1c4d788aa1
commit
b89d6f4070
6 changed files with 79 additions and 13 deletions
|
|
@ -90,8 +90,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
|
||||||
# Load the model
|
# Load the model
|
||||||
model = SAM("mobile_sam.pt")
|
model = SAM("mobile_sam.pt")
|
||||||
|
|
||||||
# Predict a segment based on a point prompt
|
# Predict a segment based on a single point prompt
|
||||||
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Predict multiple segments based on multiple points prompt
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||||
|
|
||||||
|
# Predict a segment based on multiple points prompt per object
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Predict a segment using both positive and negative prompts.
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
### Box Prompt
|
### Box Prompt
|
||||||
|
|
@ -106,8 +115,17 @@ You can download the model [here](https://github.com/ChaoningZhang/MobileSAM/blo
|
||||||
# Load the model
|
# Load the model
|
||||||
model = SAM("mobile_sam.pt")
|
model = SAM("mobile_sam.pt")
|
||||||
|
|
||||||
# Predict a segment based on a box prompt
|
# Predict a segment based on a single point prompt
|
||||||
model.predict("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
model.predict("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Predict mutiple segments based on multiple points prompt
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||||
|
|
||||||
|
# Predict a segment based on multiple points prompt per object
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Predict a segment using both positive and negative prompts.
|
||||||
|
model.predict("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](sam.md).
|
We have implemented `MobileSAM` and `SAM` using the same API. For more usage information, please see the [SAM page](sam.md).
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,17 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
|
||||||
# Run inference with bboxes prompt
|
# Run inference with bboxes prompt
|
||||||
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
results = model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||||
|
|
||||||
# Run inference with points prompt
|
# Run inference with single point
|
||||||
results = model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
results = predictor(points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Run inference with multiple points
|
||||||
|
results = predictor(points=[[400, 370], [900, 370]], labels=[1, 1])
|
||||||
|
|
||||||
|
# Run inference with multiple points prompt per object
|
||||||
|
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Run inference with negative points prompt
|
||||||
|
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
!!! example "Segment everything"
|
!!! example "Segment everything"
|
||||||
|
|
@ -107,8 +116,16 @@ The Segment Anything Model can be employed for a multitude of downstream tasks t
|
||||||
predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file
|
predictor.set_image("ultralytics/assets/zidane.jpg") # set with image file
|
||||||
predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray
|
predictor.set_image(cv2.imread("ultralytics/assets/zidane.jpg")) # set with np.ndarray
|
||||||
results = predictor(bboxes=[439, 437, 524, 709])
|
results = predictor(bboxes=[439, 437, 524, 709])
|
||||||
|
|
||||||
|
# Run inference with single point prompt
|
||||||
results = predictor(points=[900, 370], labels=[1])
|
results = predictor(points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Run inference with multiple points prompt
|
||||||
|
results = predictor(points=[[400, 370], [900, 370]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Run inference with negative points prompt
|
||||||
|
results = predictor(points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
|
|
||||||
# Reset image
|
# Reset image
|
||||||
predictor.reset_image()
|
predictor.reset_image()
|
||||||
```
|
```
|
||||||
|
|
@ -245,6 +262,15 @@ model("ultralytics/assets/zidane.jpg", bboxes=[439, 437, 524, 709])
|
||||||
|
|
||||||
# Segment with points prompt
|
# Segment with points prompt
|
||||||
model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
model("ultralytics/assets/zidane.jpg", points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Segment with multiple points prompt
|
||||||
|
model("ultralytics/assets/zidane.jpg", points=[[400, 370], [900, 370]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Segment with multiple points prompt per object
|
||||||
|
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 1]])
|
||||||
|
|
||||||
|
# Segment with negative points prompt.
|
||||||
|
model("ultralytics/assets/zidane.jpg", points=[[[400, 370], [900, 370]]], labels=[[1, 0]])
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, you can run inference with SAM in the command line interface (CLI):
|
Alternatively, you can run inference with SAM in the command line interface (CLI):
|
||||||
|
|
|
||||||
|
|
@ -97,9 +97,12 @@ def test_mobilesam():
|
||||||
# Source
|
# Source
|
||||||
source = ASSETS / "zidane.jpg"
|
source = ASSETS / "zidane.jpg"
|
||||||
|
|
||||||
# Predict a segment based on a point prompt
|
# Predict a segment based on a 1D point prompt and 1D labels.
|
||||||
model.predict(source, points=[900, 370], labels=[1])
|
model.predict(source, points=[900, 370], labels=[1])
|
||||||
|
|
||||||
|
# Predict a segment based on 3D points and 2D labels (multiple points per object).
|
||||||
|
model.predict(source, points=[[[900, 370], [1000, 100]]], labels=[[1, 1]])
|
||||||
|
|
||||||
# Predict a segment based on a box prompt
|
# Predict a segment based on a box prompt
|
||||||
model.predict(source, bboxes=[439, 437, 524, 709], save=True)
|
model.predict(source, bboxes=[439, 437, 524, 709], save=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,9 +127,21 @@ def test_predict_sam():
|
||||||
# Run inference with bboxes prompt
|
# Run inference with bboxes prompt
|
||||||
model(SOURCE, bboxes=[439, 437, 524, 709], device=0)
|
model(SOURCE, bboxes=[439, 437, 524, 709], device=0)
|
||||||
|
|
||||||
# Run inference with points prompt
|
# Run inference with no labels
|
||||||
|
model(ASSETS / "zidane.jpg", points=[900, 370], device=0)
|
||||||
|
|
||||||
|
# Run inference with 1D points and 1D labels
|
||||||
model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0)
|
model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0)
|
||||||
|
|
||||||
|
# Run inference with 2D points and 1D labels
|
||||||
|
model(ASSETS / "zidane.jpg", points=[[900, 370]], labels=[1], device=0)
|
||||||
|
|
||||||
|
# Run inference with multiple 2D points and 1D labels
|
||||||
|
model(ASSETS / "zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1], device=0)
|
||||||
|
|
||||||
|
# Run inference with 3D points and 2D labels (multiple points per object)
|
||||||
|
model(ASSETS / "zidane.jpg", points=[[[900, 370], [1000, 100]]], labels=[[1, 1]], device=0)
|
||||||
|
|
||||||
# Create SAMPredictor
|
# Create SAMPredictor
|
||||||
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=WEIGHTS_DIR / "mobile_sam.pt")
|
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=WEIGHTS_DIR / "mobile_sam.pt")
|
||||||
predictor = SAMPredictor(overrides=overrides)
|
predictor = SAMPredictor(overrides=overrides)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.3.11"
|
__version__ = "8.3.12"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -213,11 +213,14 @@ class Predictor(BasePredictor):
|
||||||
Args:
|
Args:
|
||||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
|
||||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
|
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
|
||||||
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
|
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
|
||||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple): Tuple containing:
|
(tuple): Tuple containing:
|
||||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||||
|
|
@ -240,9 +243,13 @@ class Predictor(BasePredictor):
|
||||||
points = points[None] if points.ndim == 1 else points
|
points = points[None] if points.ndim == 1 else points
|
||||||
# Assuming labels are all positive if users don't pass labels.
|
# Assuming labels are all positive if users don't pass labels.
|
||||||
if labels is None:
|
if labels is None:
|
||||||
labels = np.ones(points.shape[0])
|
labels = np.ones(points.shape[:-1])
|
||||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||||
|
assert (
|
||||||
|
points.shape[-2] == labels.shape[-1]
|
||||||
|
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
|
||||||
points *= r
|
points *= r
|
||||||
|
if points.ndim == 2:
|
||||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||||
points, labels = points[:, None, :], labels[:, None]
|
points, labels = points[:, None, :], labels[:, None]
|
||||||
if bboxes is not None:
|
if bboxes is not None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue