ultralytics 8.3.12 SAM and SAM2 multi-point prompts (#16643)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Mohammed Yasin 2024-10-13 23:20:40 +08:00 committed by GitHub
parent 1c4d788aa1
commit b89d6f4070
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 79 additions and 13 deletions

View file

@ -213,11 +213,14 @@ class Predictor(BasePredictor):
Args:
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
Raises:
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
(tuple): Tuple containing:
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
@ -240,11 +243,15 @@ class Predictor(BasePredictor):
points = points[None] if points.ndim == 1 else points
# Assuming labels are all positive if users don't pass labels.
if labels is None:
labels = np.ones(points.shape[0])
labels = np.ones(points.shape[:-1])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
assert (
points.shape[-2] == labels.shape[-1]
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
points *= r
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if points.ndim == 2:
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
points, labels = points[:, None, :], labels[:, None]
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes