Tests and docstrings improvements (#4475)
This commit is contained in:
parent
c659c0fa7b
commit
615ddc9d97
22 changed files with 107 additions and 186 deletions
|
|
@ -24,7 +24,7 @@ class MaskDecoder(nn.Module):
|
|||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a transformer architecture.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
transformer_dim (int): the channel dimension of the transformer module
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
|
||||
|
|
@ -65,7 +65,7 @@ class MaskDecoder(nn.Module):
|
|||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Arguments:
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue