Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -29,7 +29,7 @@ class Sam(nn.Module):
|
|||
"""
|
||||
SAM predicts object masks from an image and input prompts.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the
|
||||
image into image embeddings that allow for efficient mask prediction.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
|
|
@ -60,14 +60,12 @@ class Sam(nn.Module):
|
|||
multimask_output: bool,
|
||||
) -> List[Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Predicts masks end-to-end from provided images and prompts.
|
||||
If prompts are not known in advance, using SamPredictor is
|
||||
recommended over calling the model directly.
|
||||
Predicts masks end-to-end from provided images and prompts. If prompts are not known in advance, using
|
||||
SamPredictor is recommended over calling the model directly.
|
||||
|
||||
Arguments:
|
||||
batched_input (list(dict)): A list over input images, each a
|
||||
dictionary with the following keys. A prompt key can be
|
||||
excluded if it is not present.
|
||||
Args:
|
||||
batched_input (list(dict)): A list over input images, each a dictionary with the following keys. A prompt
|
||||
key can be excluded if it is not present.
|
||||
'image': The image as a torch tensor in 3xHxW format,
|
||||
already transformed for input to the model.
|
||||
'original_size': (tuple(int, int)) The original size of
|
||||
|
|
@ -81,12 +79,11 @@ class Sam(nn.Module):
|
|||
Already transformed to the input frame of the model.
|
||||
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
||||
in the form Bx1xHxW.
|
||||
multimask_output (bool): Whether the model should predict multiple
|
||||
disambiguating masks, or return a single mask.
|
||||
multimask_output (bool): Whether the model should predict multiple disambiguating masks, or return a single
|
||||
mask.
|
||||
|
||||
Returns:
|
||||
(list(dict)): A list over input images, where each element is
|
||||
a dictionary with the following keys.
|
||||
(list(dict)): A list over input images, where each element is a dictionary with the following keys.
|
||||
'masks': (torch.Tensor) Batched binary mask predictions,
|
||||
with shape BxCxHxW, where B is the number of input prompts,
|
||||
C is determined by multimask_output, and (H, W) is the
|
||||
|
|
@ -139,7 +136,7 @@ class Sam(nn.Module):
|
|||
"""
|
||||
Remove padding and upscale masks to the original image size.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
masks (torch.Tensor): Batched masks from the mask_decoder,
|
||||
in BxCxHxW format.
|
||||
input_size (tuple(int, int)): The size of the image input to the
|
||||
|
|
@ -158,8 +155,7 @@ class Sam(nn.Module):
|
|||
align_corners=False,
|
||||
)
|
||||
masks = masks[..., :input_size[0], :input_size[1]]
|
||||
masks = F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
|
||||
return masks
|
||||
return F.interpolate(masks, original_size, mode='bilinear', align_corners=False)
|
||||
|
||||
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize pixel values and pad to a square input."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue