Implement all missing docstrings (#5298)

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-10-10 20:07:13 +02:00 committed by GitHub
parent e7f0658744
commit 7fd5dcbd86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 649 additions and 79 deletions

View file

@ -12,6 +12,18 @@ from ultralytics.nn.modules import LayerNorm2d, MLPBlock
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT(nn.Module):
"""
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
The encoded patches are then processed through a neck to generate the final encoded representation.
Attributes:
img_size (int): Dimension of input images, assumed to be square.
patch_embed (PatchEmbed): Module for patch embedding.
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
neck (nn.Sequential): Neck module to further process the output.
"""
def __init__(
self,
@ -112,6 +124,22 @@ class ImageEncoderViT(nn.Module):
class PromptEncoder(nn.Module):
"""
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
produces both sparse and dense embeddings for the input prompts.
Attributes:
embed_dim (int): Dimension of the embeddings.
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
num_point_embeddings (int): Number of point embeddings for different types of points.
point_embeddings (nn.ModuleList): List of point embeddings.
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
mask_input_size (Tuple[int, int]): Size of the input mask.
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
"""
def __init__(
self,