Add utils.ops and nn.modules to tests (#4484)
This commit is contained in:
parent
1cec0185a1
commit
6da8f7f51e
14 changed files with 246 additions and 330 deletions
|
|
@ -30,11 +30,10 @@ class Sam(nn.Module):
|
|||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the
|
||||
image into image embeddings that allow for efficient mask prediction.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings that allow for
|
||||
efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
|
||||
and encoded prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
||||
"""
|
||||
|
|
@ -65,34 +64,25 @@ class Sam(nn.Module):
|
|||
|
||||
Args:
|
||||
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
||||
key can be excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format,
|
||||
already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of
|
||||
the image before transformation, as (H, W).
|
||||
'point_coords': (torch.Tensor) Batched point prompts for
|
||||
this image, with shape BxNx2. Already transformed to the
|
||||
input frame of the model.
|
||||
'point_labels': (torch.Tensor) Batched labels for point prompts,
|
||||
with shape BxN.
|
||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
|
||||
Already transformed to the input frame of the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
||||
in the form Bx1xHxW.
|
||||
key can be excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format, already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of the image before transformation, as (H, W).
|
||||
'point_coords': (torch.Tensor) Batched point prompts for this image, with shape BxNx2. Already
|
||||
transformed to the input frame of the model.
|
||||
'point_labels': (torch.Tensor) Batched labels for point prompts, with shape BxN.
|
||||
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. Already transformed to the input frame of
|
||||
the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model, in the form Bx1xHxW.
|
||||
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
(list(dict)): A list over input images, where each element is as dictionary with the following keys.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
||||
with shape BxCxHxW, where B is the number of input prompts,
|
||||
C is determined by multimask_output, and (H, W) is the
|
||||
original size of the image.
|
||||
'iou_predictions': (torch.Tensor) The model's predictions
|
||||
of mask quality, in shape BxC.
|
||||
'low_res_logits': (torch.Tensor) Low resolution logits with
|
||||
shape BxCxHxW, where H=W=256. Can be passed as mask input
|
||||
to subsequent iterations of prediction.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions, with shape BxCxHxW, where B is the number of
|
||||
input prompts, C is determined by multimask_output, and (H, W) is the original size of the image.
|
||||
'iou_predictions': (torch.Tensor) The model's predictions of mask quality, in shape BxC.
|
||||
'low_res_logits': (torch.Tensor) Low resolution logits with shape BxCxHxW, where H=W=256. Can be passed
|
||||
as mask input to subsequent iterations of prediction.
|
||||
"""
|
||||
input_images = torch.stack([self.preprocess(x['image']) for x in batched_input], dim=0)
|
||||
image_embeddings = self.image_encoder(input_images)
|
||||
|
|
@ -137,16 +127,12 @@ class Sam(nn.Module):
|
|||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
||||
in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the image input to the
|
||||
model, in (H, W) format. Used to remove padding.
|
||||
original_size (tuple(int, int)): The original size of the image
|
||||
before resizing for input to the model, in (H, W) format.
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder, in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the model input image, in (H, W) format. Used to remove padding.
|
||||
original_size (tuple(int, int)): The original image size before resizing for input to the model, in (H, W).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
||||
is given by original_size.
|
||||
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W) is given by original_size.
|
||||
"""
|
||||
masks = F.interpolate(
|
||||
masks,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue