ultralytics 8.3.12 SAM and SAM2 multi-point prompts (#16643)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1c4d788aa1
commit
b89d6f4070
6 changed files with 79 additions and 13 deletions
|
|
@ -213,11 +213,14 @@ class Predictor(BasePredictor):
|
|||
Args:
|
||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background.
|
||||
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
|
|
@ -240,11 +243,15 @@ class Predictor(BasePredictor):
|
|||
points = points[None] if points.ndim == 1 else points
|
||||
# Assuming labels are all positive if users don't pass labels.
|
||||
if labels is None:
|
||||
labels = np.ones(points.shape[0])
|
||||
labels = np.ones(points.shape[:-1])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
assert (
|
||||
points.shape[-2] == labels.shape[-1]
|
||||
), f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}."
|
||||
points *= r
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None, :], labels[:, None]
|
||||
if points.ndim == 2:
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None, :], labels[:, None]
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue