Cleanup redundant SAM forward() methods (#4591)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher 2023-08-27 19:19:13 +02:00 committed by GitHub
parent 47ab96dab6
commit 2567b288c9
4 changed files with 14 additions and 118 deletions


@@ -6,11 +6,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Tuple
from typing import List
import torch
from torch import nn
from torch.nn import functional as F
from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
@@ -31,6 +30,9 @@ class Sam(nn.Module):
"""
SAM predicts object masks from an image and input prompts.
Note:
All forward() operations moved to SAMPredictor.
Args:
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
efficient mask prediction.
@@ -45,109 +47,3 @@ class Sam(nn.Module):
self.mask_decoder = mask_decoder
self.register_buffer('pixel_mean', torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer('pixel_std', torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
def device(self) -> Any:
return self.pixel_mean.device
@torch.no_grad()
def forward(
self,
batched_input: List[Dict[str, Any]],
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
SamPredictor is recommended over calling the model directly.
Args:
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
key can be excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
transformed to the input frame of the model.
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
mask.
Returns:
(list(dict)): A list over input images, where each element is a dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
as mask input to subsequent iterations of prediction.
"""
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
image_embeddings = self.image_encoder(input_images)
outputs = []
for image_record, curr_embedding in zip(batched_input, image_embeddings):
if 'point_coords' in image_record:
points = (image_record['point_coords'], image_record['point_labels'])
else:
points = None
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=image_record.get('boxes', None),
masks=image_record.get('mask_inputs', None),
)
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
masks = self.postprocess_masks(
low_res_masks,
input_size=image_record['image'].shape[-2:],
original_size=image_record['original_size'],
)
masks = masks > self.mask_threshold
outputs.append({
'masks': masks,
'iou_predictions': iou_predictions,
'low_res_logits': low_res_masks, })
return outputs
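With forward() removed from Sam, prompted inference now goes through the predictor instead. A minimal sketch of equivalent usage via the public Ultralytics SAM interface follows; the sam_b.pt checkpoint name, keyword arguments, and result fields are assumptions and may differ by version:

```python
# Hedged sketch: prompted SAM inference after forward() moved to SAMPredictor.
# Assumes the ultralytics package and a local sam_b.pt checkpoint (illustrative).
from ultralytics import SAM

model = SAM('sam_b.pt')  # wraps Sam and runs it through the predictor
# Single positive point prompt at (x, y); bboxes=[x1, y1, x2, y2] is also accepted.
results = model('image.jpg', points=[[500, 375]], labels=[1])
for r in results:
    print(r.masks.data.shape)  # (num_masks, H, W) at the original image size
```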
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
Args:
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode='bilinear',
align_corners=False,
)
masks = masks[..., :input_size[0], :input_size[1]]
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
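The postprocess_masks() step above first upsamples decoder logits to the padded model resolution, crops away the padding, then rescales to the original image size. A small self-contained sketch with dummy shapes (the concrete sizes are illustrative assumptions):

```python
# Hedged sketch of the padding-removal / rescaling done by postprocess_masks().
import torch
import torch.nn.functional as F

img_size = 1024                        # image_encoder.img_size
low_res = torch.randn(1, 3, 256, 256)  # BxCxHxW logits from the mask decoder
input_size = (1024, 683)               # padded model input (H, W) for a 1500x1000 image
original_size = (1500, 1000)           # original image size (H, W)

masks = F.interpolate(low_res, (img_size, img_size), mode='bilinear', align_corners=False)
masks = masks[..., :input_size[0], :input_size[1]]  # drop right/bottom padding
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
print(masks.shape)  # torch.Size([1, 3, 1500, 1000])
```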
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
return F.pad(x, (0, padw, 0, padh))
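For reference, the removed preprocess() normalizes with the SAM pixel statistics and pads the resized image to a square. A minimal sketch on a dummy input; the mean/std values are the standard SAM defaults passed to Sam.__init__ (not shown in this diff, so treated as an assumption) and the shapes are illustrative:

```python
# Hedged sketch of preprocess(): normalize colors, then pad right/bottom to a square.
import torch
import torch.nn.functional as F

img_size = 1024  # image_encoder.img_size
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.randint(0, 256, (3, 1024, 683)).float()  # 3xHxW image, longest side resized to 1024
x = (x - pixel_mean) / pixel_std                    # normalize colors
h, w = x.shape[-2:]
x = F.pad(x, (0, img_size - w, 0, img_size - h))    # pad width, then height
print(x.shape)  # torch.Size([3, 1024, 1024])
```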