Tests and docstrings improvements (#4475)

Glenn Jocher 2023-08-21 17:02:14 +02:00 committed by GitHub
parent c659c0fa7b
commit 615ddc9d97
22 changed files with 107 additions and 186 deletions


@@ -29,7 +29,7 @@ class Sam(nn.Module):
"""
SAM predicts object masks from an image and input prompts.
- Arguments:
+ Args:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
@@ -60,14 +60,12 @@ class Sam(nn.Module):
multimask_output: bool,
) -> List[Dict[str, torch.Tensor]]:
"""
- Predicts masks end-to-end from provided images and prompts.
- If prompts are not known in advance, using SamPredictor is
- recommended over calling the model directly.
+ Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
+ SamPredictor is recommended over calling the model directly.
- Arguments:
- batched_input (list(dict)): A list over input images, each a
- dictionary with the following keys. A prompt key can be
- excluded if it is not present.
+ Args:
+ batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
+ key can be excluded if it is not present.
'image': The image as a torch tensor in 3xHxW format,
already transformed for input to the model.
'original_size': (tuple(int, int)) The original size of
@@ -81,12 +79,11 @@ class Sam(nn.Module):
Already transformed to the input frame of the model.
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
in the form Bx1xHxW.
- multimask_output (bool): Whether the model should predict multiple
- disambiguating masks, or return a single mask.
+ multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
+ mask.
Returns:
- (list(dict)): A list over input images, where each element is
- as dictionary with the following keys.
+ (list(dict)): A list over input images, where each element is as dictionary with the following keys.
'masks': (torch.Tensor) Batched binary mask predictions,
with shape BxCxHxW, where B is the number of input prompts,
C is determined by multimask_output, and (H, W) is the
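
For reference, a minimal sketch of how the batched_input list described in this docstring might be assembled. The 1024x1024 input resolution, the point-prompt keys, and the sam_model name are assumptions for illustration, not taken from this diff.

```python
import torch

# Hypothetical example of the batched_input structure documented above.
# The 1024x1024 model input size and the prompt keys beyond 'image' and
# 'original_size' are assumptions, not shown in this diff.
image = torch.randn(3, 1024, 1024)                          # 'image': 3xHxW, already transformed
original_size = (720, 1280)                                 # image size before transformation
point_coords = torch.randint(0, 1024, (1, 2, 2)).float()    # assumed BxNx2 point prompts, input frame
point_labels = torch.ones(1, 2)                             # assumed BxN point labels

batched_input = [{
    'image': image,
    'original_size': original_size,
    'point_coords': point_coords,
    'point_labels': point_labels,
    # prompt keys such as 'boxes' or 'mask_inputs' can simply be omitted
}]

# outputs = sam_model(batched_input, multimask_output=True)   # sam_model is hypothetical here
# each element of outputs would then contain a 'masks' tensor of shape BxCxHxW
```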
@@ -139,7 +136,7 @@ class Sam(nn.Module):
"""
Remove padding and upscale masks to the original image size.
- Arguments:
+ Args:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
@@ -158,8 +155,7 @@ class Sam(nn.Module):
align_corners=False,
)
masks = masks[..., :input_size[0], :input_size[1]]
- masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
- return masks
+ return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
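
The two interpolate calls above implement an upscale-crop-resize chain: masks are first upscaled to the square model input, the preprocessing padding is stripped, and the result is resized to the original image size. Below is a standalone sketch with dummy tensors, assuming the common 1024x1024 encoder resolution and 256x256 decoder masks (both assumptions, not stated in this diff).

```python
import torch
import torch.nn.functional as F

# Dummy-tensor walk-through of the resize/crop/resize chain shown above.
img_size = 1024                                # assumed square model input side
input_size = (576, 1024)                       # image area inside the padded square input
original_size = (720, 1280)                    # image size before any transformation
low_res_masks = torch.randn(1, 3, 256, 256)    # assumed BxCxHxW output of the mask decoder

masks = F.interpolate(low_res_masks, (img_size, img_size), mode='bilinear', align_corners=False)
masks = masks[..., :input_size[0], :input_size[1]]        # drop the right/bottom padding
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
print(masks.shape)                             # torch.Size([1, 3, 720, 1280])
```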
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""