ultralytics 8.2.73 Meta SAM2 Refactor (#14867)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
bea4c93278
commit
5d9046abda
44 changed files with 4542 additions and 3624 deletions
|
|
@ -4,7 +4,6 @@ from .fastsam import FastSAM
|
|||
from .nas import NAS
|
||||
from .rtdetr import RTDETR
|
||||
from .sam import SAM
|
||||
from .sam2 import SAM2
|
||||
from .yolo import YOLO, YOLOWorld
|
||||
|
||||
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import
|
||||
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import SAM
|
||||
from .predict import Predictor
|
||||
from .predict import Predictor, SAM2Predictor
|
||||
|
||||
__all__ = "SAM", "Predictor" # tuple or list
|
||||
__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import torch
|
|||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""Return a boolean tensor indicating if boxes are near the crop edge."""
|
||||
"""Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance."""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
|
|
@ -22,7 +22,7 @@ def is_box_near_crop_edge(
|
|||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
||||
"""Yield batches of data from the input arguments."""
|
||||
"""Yields batches of data from input arguments with specified batch size for efficient processing."""
|
||||
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
|
|
@ -33,12 +33,26 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
|
|||
"""
|
||||
Computes the stability score for a batch of masks.
|
||||
|
||||
The stability score is the IoU between the binary masks obtained by thresholding the predicted mask logits at high
|
||||
and low values.
|
||||
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
|
||||
high and low values.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Batch of predicted mask logits.
|
||||
mask_threshold (float): Threshold value for creating binary masks.
|
||||
threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Stability scores for each mask in the batch.
|
||||
|
||||
Notes:
|
||||
- One mask is always contained inside the other.
|
||||
- Save memory by preventing unnecessary cast to torch.int64
|
||||
- Memory is saved by preventing unnecessary cast to torch.int64.
|
||||
|
||||
Examples:
|
||||
>>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
|
||||
>>> mask_threshold = 0.5
|
||||
>>> threshold_offset = 0.1
|
||||
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
|
||||
"""
|
||||
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
|
|
@ -46,7 +60,7 @@ def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, thresh
|
|||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1]."""
|
||||
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
|
|
@ -55,18 +69,14 @@ def build_point_grid(n_per_side: int) -> np.ndarray:
|
|||
|
||||
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]:
|
||||
"""Generate point grids for all crop layers."""
|
||||
"""Generates point grids for multiple crop layers with varying scales and densities."""
|
||||
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
|
||||
|
||||
|
||||
def generate_crop_boxes(
|
||||
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> Tuple[List[List[int]], List[int]]:
|
||||
"""
|
||||
Generates a list of crop boxes of different sizes.
|
||||
|
||||
Each layer has (2**i)**2 boxes for the ith layer.
|
||||
"""
|
||||
"""Generates crop boxes of varying sizes for multi-scale image processing, with layered overlapping regions."""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
|
@ -99,7 +109,7 @@ def generate_crop_boxes(
|
|||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
"""Uncrop bounding boxes by adding the crop box offset."""
|
||||
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
|
|
@ -109,7 +119,7 @@ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
|||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
||||
"""Uncrop points by adding the crop box offset."""
|
||||
"""Uncrop points by adding the crop box offset to their coordinates."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
|
|
@ -119,7 +129,7 @@ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
|||
|
||||
|
||||
def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor:
|
||||
"""Uncrop masks by padding them to the original image size."""
|
||||
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
|
|
@ -130,7 +140,7 @@ def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w:
|
|||
|
||||
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]:
|
||||
"""Remove small disconnected regions or holes in a mask, returning the mask and a modification indicator."""
|
||||
"""Removes small disconnected regions or holes in a mask based on area threshold and mode."""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
|
||||
|
|
@ -150,11 +160,7 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
|
|||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculates boxes in XYXY format around masks.
|
||||
|
||||
Return [0,0,0,0] for an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
||||
"""
|
||||
"""Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes."""
|
||||
# torch.max below raises an error on empty inputs, just skip in this case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
|
|
|||
|
|
@ -13,14 +13,15 @@ import torch
|
|||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.decoders import MaskDecoder
|
||||
from .modules.encoders import ImageEncoderViT, PromptEncoder
|
||||
from .modules.sam import SAMModel
|
||||
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
|
||||
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
|
||||
from .modules.sam import SAM2Model, SAMModel
|
||||
from .modules.tiny_encoder import TinyViT
|
||||
from .modules.transformer import TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) h-size model."""
|
||||
"""Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
|
|
@ -31,7 +32,7 @@ def build_sam_vit_h(checkpoint=None):
|
|||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) l-size model."""
|
||||
"""Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
|
|
@ -42,7 +43,7 @@ def build_sam_vit_l(checkpoint=None):
|
|||
|
||||
|
||||
def build_sam_vit_b(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) b-size model."""
|
||||
"""Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
|
|
@ -53,7 +54,7 @@ def build_sam_vit_b(checkpoint=None):
|
|||
|
||||
|
||||
def build_mobile_sam(checkpoint=None):
|
||||
"""Build and return Mobile Segment Anything Model (Mobile-SAM)."""
|
||||
"""Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=[64, 128, 160, 320],
|
||||
encoder_depth=[2, 2, 6, 2],
|
||||
|
|
@ -64,10 +65,85 @@ def build_mobile_sam(checkpoint=None):
|
|||
)
|
||||
|
||||
|
||||
def build_sam2_t(checkpoint=None):
|
||||
"""Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 7, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[5, 7, 9],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_s(checkpoint=None):
|
||||
"""Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 11, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[7, 10, 13],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_b(checkpoint=None):
|
||||
"""Builds and returns a SAM2 base-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=112,
|
||||
encoder_stages=[2, 3, 16, 3],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[12, 16, 20],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_window_spatial_size=[14, 14],
|
||||
encoder_backbone_channel_list=[896, 448, 224, 112],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_l(checkpoint=None):
|
||||
"""Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=144,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[23, 33, 43],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_sam(
|
||||
encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint=None, mobile_sam=False
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
mobile_sam=False,
|
||||
):
|
||||
"""Builds the selected SAM model architecture."""
|
||||
"""
|
||||
Builds a Segment Anything Model (SAM) with specified encoder parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int | List[int]): Embedding dimension for the encoder.
|
||||
encoder_depth (int | List[int]): Depth of the encoder.
|
||||
encoder_num_heads (int | List[int]): Number of attention heads in the encoder.
|
||||
encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder.
|
||||
checkpoint (str | None): Path to the model checkpoint file.
|
||||
mobile_sam (bool): Whether to build a Mobile-SAM model.
|
||||
|
||||
Returns:
|
||||
(SAMModel): A Segment Anything Model instance with the specified architecture.
|
||||
|
||||
Examples:
|
||||
>>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
|
||||
>>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
|
||||
"""
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
|
|
@ -139,16 +215,131 @@ def _build_sam(
|
|||
return sam
|
||||
|
||||
|
||||
def _build_sam2(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[7, 15, 23, 31],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
encoder_window_spatial_size=[7, 7],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
checkpoint=None,
|
||||
):
|
||||
"""
|
||||
Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int): Embedding dimension for the encoder.
|
||||
encoder_stages (List[int]): Number of blocks in each stage of the encoder.
|
||||
encoder_num_heads (int): Number of attention heads in the encoder.
|
||||
encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder.
|
||||
encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone.
|
||||
encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings.
|
||||
encoder_window_spec (List[int]): Window specifications for each stage of the encoder.
|
||||
checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights.
|
||||
|
||||
Returns:
|
||||
(SAM2Model): A configured and initialized SAM2 model.
|
||||
|
||||
Examples:
|
||||
>>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
|
||||
>>> sam2_model.eval()
|
||||
"""
|
||||
image_encoder = ImageEncoder(
|
||||
trunk=Hiera(
|
||||
embed_dim=encoder_embed_dim,
|
||||
num_heads=encoder_num_heads,
|
||||
stages=encoder_stages,
|
||||
global_att_blocks=encoder_global_att_blocks,
|
||||
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
|
||||
window_spec=encoder_window_spec,
|
||||
),
|
||||
neck=FpnNeck(
|
||||
d_model=256,
|
||||
backbone_channel_list=encoder_backbone_channel_list,
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
scalp=1,
|
||||
)
|
||||
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
|
||||
memory_encoder = MemoryEncoder(out_dim=64)
|
||||
|
||||
sam2 = SAM2Model(
|
||||
image_encoder=image_encoder,
|
||||
memory_attention=memory_attention,
|
||||
memory_encoder=memory_encoder,
|
||||
num_maskmem=7,
|
||||
image_size=1024,
|
||||
sigmoid_scale_for_mem_enc=20.0,
|
||||
sigmoid_bias_for_mem_enc=-10.0,
|
||||
use_mask_input_as_output_without_sam=True,
|
||||
directly_add_no_mem_embed=True,
|
||||
use_high_res_features_in_sam=True,
|
||||
multimask_output_in_sam=True,
|
||||
iou_prediction_use_sigmoid=True,
|
||||
use_obj_ptrs_in_encoder=True,
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
only_obj_ptrs_in_the_past_for_eval=True,
|
||||
pred_obj_scores=True,
|
||||
pred_obj_scores_mlp=True,
|
||||
fixed_no_obj_ptr=True,
|
||||
multimask_output_for_tracking=True,
|
||||
use_multimask_token_for_obj_ptr=True,
|
||||
multimask_min_pt_num=0,
|
||||
multimask_max_pt_num=1,
|
||||
use_mlp_for_obj_ptr_proj=True,
|
||||
compile_image_encoder=False,
|
||||
sam_mask_decoder_extra_args=dict(
|
||||
dynamic_multimask_via_stability=True,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
),
|
||||
)
|
||||
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)["model"]
|
||||
sam2.load_state_dict(state_dict)
|
||||
sam2.eval()
|
||||
return sam2
|
||||
|
||||
|
||||
sam_model_map = {
|
||||
"sam_h.pt": build_sam_vit_h,
|
||||
"sam_l.pt": build_sam_vit_l,
|
||||
"sam_b.pt": build_sam_vit_b,
|
||||
"mobile_sam.pt": build_mobile_sam,
|
||||
"sam2_t.pt": build_sam2_t,
|
||||
"sam2_s.pt": build_sam2_s,
|
||||
"sam2_b.pt": build_sam2_b,
|
||||
"sam2_l.pt": build_sam2_l,
|
||||
}
|
||||
|
||||
|
||||
def build_sam(ckpt="sam_b.pt"):
|
||||
"""Build a SAM model specified by ckpt."""
|
||||
"""
|
||||
Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint.
|
||||
|
||||
Args:
|
||||
ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model.
|
||||
|
||||
Returns:
|
||||
(SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the provided checkpoint is not a supported SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam_model = build_sam("sam_b.pt")
|
||||
>>> sam_model = build_sam("path/to/custom_checkpoint.pt")
|
||||
|
||||
Notes:
|
||||
Supported pre-defined models include:
|
||||
- SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
|
||||
- SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
|
||||
"""
|
||||
model_builder = None
|
||||
ckpt = str(ckpt) # to allow Path ckpt types
|
||||
for k in sam_model_map.keys():
|
||||
|
|
|
|||
|
|
@ -20,27 +20,46 @@ from ultralytics.engine.model import Model
|
|||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .build import build_sam
|
||||
from .predict import Predictor
|
||||
from .predict import Predictor, SAM2Predictor
|
||||
|
||||
|
||||
class SAM(Model):
|
||||
"""
|
||||
SAM (Segment Anything Model) interface class.
|
||||
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
|
||||
|
||||
SAM is designed for promptable real-time image segmentation. It can be used with a variety of prompts such as
|
||||
bounding boxes, points, or labels. The model has capabilities for zero-shot performance and is trained on the SA-1B
|
||||
dataset.
|
||||
This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for
|
||||
promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
|
||||
boxes, points, or labels, and features zero-shot performance capabilities.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The loaded SAM model.
|
||||
is_sam2 (bool): Indicates whether the model is SAM2 variant.
|
||||
task (str): The task type, set to "segment" for SAM models.
|
||||
|
||||
Methods:
|
||||
predict: Performs segmentation prediction on the given image or video source.
|
||||
info: Logs information about the SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
>>> print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
||||
def __init__(self, model="sam_b.pt") -> None:
|
||||
"""
|
||||
Initializes the SAM model with a pre-trained model file.
|
||||
Initializes the SAM (Segment Anything Model) instance.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> print(sam.is_sam2)
|
||||
"""
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||
|
|
@ -51,30 +70,40 @@ class SAM(Model):
|
|||
"""
|
||||
Loads the specified weights into the SAM model.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file.
|
||||
task (str, optional): Task name. Defaults to None.
|
||||
"""
|
||||
if self.is_sam2:
|
||||
from ..sam2.build import build_sam2
|
||||
This method initializes the SAM model with the provided weights file, setting up the model architecture
|
||||
and loading the pre-trained parameters.
|
||||
|
||||
self.model = build_sam2(weights)
|
||||
else:
|
||||
self.model = build_sam(weights)
|
||||
Args:
|
||||
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
|
||||
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> sam._load('path/to/custom_weights.pt')
|
||||
"""
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Performs segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or
|
||||
a numpy.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (List[List[float]] | None): List of points for prompted segmentation.
|
||||
labels (List[int] | None): List of labels for prompted segmentation.
|
||||
**kwargs (Any): Additional keyword arguments for prediction.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
(List): The model predictions.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam.predict('image.jpg', points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
... print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
|
||||
kwargs.update(overrides)
|
||||
|
|
@ -83,17 +112,27 @@ class SAM(Model):
|
|||
|
||||
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Alias for the 'predict' method.
|
||||
Performs segmentation prediction on the given image or video source.
|
||||
|
||||
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
|
||||
for segmentation tasks.
|
||||
|
||||
Args:
|
||||
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image
|
||||
object, or a numpy.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (List[List[float]] | None): List of points for prompted segmentation.
|
||||
labels (List[int] | None): List of labels for prompted segmentation.
|
||||
**kwargs (Any): Additional keyword arguments to be passed to the predict method.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
(List): The model predictions, typically containing segmentation masks and other relevant information.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> results = sam('image.jpg', points=[[500, 375]])
|
||||
>>> print(f"Detected {len(results[0].masks)} masks")
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
|
|
@ -101,12 +140,20 @@ class SAM(Model):
|
|||
"""
|
||||
Logs information about the SAM model.
|
||||
|
||||
This method provides details about the Segment Anything Model (SAM), including its architecture,
|
||||
parameters, and computational requirements.
|
||||
|
||||
Args:
|
||||
detailed (bool, optional): If True, displays detailed information about the model. Defaults to False.
|
||||
verbose (bool, optional): If True, displays information on the console. Defaults to True.
|
||||
detailed (bool): If True, displays detailed information about the model layers and operations.
|
||||
verbose (bool): If True, prints the information to the console.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the model's information.
|
||||
(Tuple): A tuple containing the model's information (string representations of the model).
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> info = sam.info()
|
||||
>>> print(info[0]) # Print summary information
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose)
|
||||
|
||||
|
|
@ -116,8 +163,13 @@ class SAM(Model):
|
|||
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
|
||||
"""
|
||||
from ..sam2.predict import SAM2Predictor
|
||||
(Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor
|
||||
class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM('sam_b.pt')
|
||||
>>> task_map = sam.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': <class 'ultralytics.models.sam.predict.Predictor'>}
|
||||
"""
|
||||
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
|
||||
|
|
|
|||
1131
ultralytics/models/sam/modules/blocks.py
Normal file
1131
ultralytics/models/sam/modules/blocks.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
|
@ -10,12 +10,14 @@ from ultralytics.nn.modules import MLP, LayerNorm2d
|
|||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores, using a transformer architecture to predict
|
||||
masks given image and prompt embeddings.
|
||||
Decoder module for generating masks and their associated quality scores using a transformer architecture.
|
||||
|
||||
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
|
||||
generate mask predictions along with their quality scores.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): The transformer module used for mask prediction.
|
||||
transformer (nn.Module): Transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for the IoU token.
|
||||
num_mask_tokens (int): Number of mask tokens.
|
||||
|
|
@ -23,6 +25,16 @@ class MaskDecoder(nn.Module):
|
|||
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
|
||||
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
|
||||
iou_prediction_head (nn.Module): MLP for predicting mask quality.
|
||||
|
||||
Methods:
|
||||
forward: Predicts masks given image and prompt embeddings.
|
||||
predict_masks: Internal method for mask prediction.
|
||||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompt_embeddings,
|
||||
... dense_prompt_embeddings, multimask_output=True)
|
||||
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -35,15 +47,20 @@ class MaskDecoder(nn.Module):
|
|||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
"""
|
||||
Predicts masks given an image and prompt embeddings, using a transformer architecture.
|
||||
Initializes the MaskDecoder module for generating masks and their quality scores.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): the channel dimension of the transformer module
|
||||
transformer (nn.Module): the transformer used to predict masks
|
||||
num_multimask_outputs (int): the number of masks to predict when disambiguating masks
|
||||
activation (nn.Module): the type of activation to use when upscaling masks
|
||||
iou_head_depth (int): the depth of the MLP used to predict mask quality
|
||||
iou_head_hidden_dim (int): the hidden dimension of the MLP used to predict mask quality
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): Transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
|
||||
iou_head_depth (int): Depth of the MLP used to predict mask quality.
|
||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
|
||||
|
||||
Examples:
|
||||
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
|
||||
>>> print(decoder)
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
|
|
@ -77,18 +94,28 @@ class MaskDecoder(nn.Module):
|
|||
multimask_output: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
Predicts masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
||||
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
||||
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
||||
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder.
|
||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
|
||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
|
||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: batched predicted masks
|
||||
torch.Tensor: batched predictions of mask quality
|
||||
(Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- masks (torch.Tensor): Batched predicted masks.
|
||||
- iou_pred (torch.Tensor): Batched predictions of mask quality.
|
||||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> image_emb = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_emb = torch.rand(1, 2, 256)
|
||||
>>> dense_emb = torch.rand(1, 256, 64, 64)
|
||||
>>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
|
||||
>>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
|
|
@ -112,11 +139,7 @@ class MaskDecoder(nn.Module):
|
|||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predicts masks.
|
||||
|
||||
See 'forward' for more details.
|
||||
"""
|
||||
"""Predicts masks and quality scores using image and prompt embeddings via transformer architecture."""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
|
||||
|
|
@ -147,3 +170,347 @@ class MaskDecoder(nn.Module):
|
|||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
class SAM2MaskDecoder(nn.Module):
|
||||
"""
|
||||
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
|
||||
|
||||
This class extends the functionality of the MaskDecoder, incorporating additional features such as
|
||||
high-resolution feature processing, dynamic multimask output, and object score prediction.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for IOU token.
|
||||
num_mask_tokens (int): Total number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for mask tokens.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
obj_score_token (nn.Embedding): Embedding for object score token.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
output_upscaling (nn.Sequential): Upscaling layers for output.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
|
||||
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
|
||||
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
|
||||
iou_prediction_head (MLP): MLP for IOU prediction.
|
||||
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
|
||||
|
||||
Methods:
|
||||
forward: Predicts masks given image and prompt embeddings.
|
||||
predict_masks: Predicts instance segmentation masks from image and prompt embeddings.
|
||||
_get_stability_scores: Computes mask stability scores based on IoU between thresholds.
|
||||
_dynamic_multimask_via_stability: Dynamically selects the most stable mask output.
|
||||
|
||||
Examples:
|
||||
>>> image_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
|
||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
use_high_res_features: bool = False,
|
||||
iou_prediction_use_sigmoid=False,
|
||||
dynamic_multimask_via_stability=False,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
pred_obj_scores: bool = False,
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SAM2MaskDecoder module for predicting instance segmentation masks.
|
||||
|
||||
This decoder extends the functionality of MaskDecoder, incorporating additional features such as
|
||||
high-resolution feature processing, dynamic multimask output, and object score prediction.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
|
||||
iou_head_depth (int): Depth of the MLP used to predict mask quality.
|
||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
|
||||
Examples:
|
||||
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
|
||||
>>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
|
||||
>>> print(decoder)
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
if self.pred_obj_scores:
|
||||
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.use_high_res_features = use_high_res_features
|
||||
if use_high_res_features:
|
||||
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
|
||||
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
|
||||
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
self.num_mask_tokens,
|
||||
iou_head_depth,
|
||||
sigmoid=iou_prediction_use_sigmoid,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
||||
if pred_obj_scores_mlp:
|
||||
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
||||
|
||||
# When outputting a single mask, optionally we can dynamically fall back to the best
|
||||
# multimask output token if the single mask output token gives low stability scores.
|
||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
||||
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
||||
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predicts masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
|
||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
|
||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
|
||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
repeat_image (bool): Flag to repeat the image embeddings.
|
||||
high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
|
||||
- iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
|
||||
- sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
|
||||
- object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
|
||||
|
||||
Examples:
|
||||
>>> image_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
|
||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
repeat_image=repeat_image,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
masks = masks[:, 1:, :, :]
|
||||
iou_pred = iou_pred[:, 1:]
|
||||
elif self.dynamic_multimask_via_stability and not self.training:
|
||||
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
||||
else:
|
||||
masks = masks[:, 0:1, :, :]
|
||||
iou_pred = iou_pred[:, 0:1]
|
||||
|
||||
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
||||
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
||||
else:
|
||||
# Take the mask output token. Here we *always* use the token for single mask output.
|
||||
# At test time, even if we track after 1-click (and using multimask_output=True),
|
||||
# we still take the single mask token here. The rationale is that we always track
|
||||
# after multiple clicks during training, so the past tokens seen during training
|
||||
# are always the single mask token (and we'll let it be the object-memory token).
|
||||
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred, sam_tokens_out, object_score_logits
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts instance segmentation masks from image and prompt embeddings using a transformer."""
|
||||
# Concatenate output tokens
|
||||
s = 0
|
||||
if self.pred_obj_scores:
|
||||
output_tokens = torch.cat(
|
||||
[
|
||||
self.obj_score_token.weight,
|
||||
self.iou_token.weight,
|
||||
self.mask_tokens.weight,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
s = 1
|
||||
else:
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
if repeat_image:
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
else:
|
||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, s, :]
|
||||
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
if not self.use_high_res_features:
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
else:
|
||||
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
||||
feat_s0, feat_s1 = high_res_features
|
||||
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
||||
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
||||
|
||||
hyper_in_list: List[torch.Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
if self.pred_obj_scores:
|
||||
assert s == 1
|
||||
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
||||
else:
|
||||
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
|
||||
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
||||
|
||||
return masks, iou_pred, mask_tokens_out, object_score_logits
|
||||
|
||||
def _get_stability_scores(self, mask_logits):
|
||||
"""Computes mask stability scores based on IoU between upper and lower thresholds."""
|
||||
mask_logits = mask_logits.flatten(-2)
|
||||
stability_delta = self.dynamic_multimask_stability_delta
|
||||
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
||||
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
||||
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
|
||||
return stability_scores
|
||||
|
||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
||||
"""
|
||||
Dynamically selects the most stable mask output based on stability scores and IoU predictions.
|
||||
|
||||
This method is used when outputting a single mask. If the stability score from the current single-mask
|
||||
output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
|
||||
(based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
|
||||
for both clicking and tracking scenarios.
|
||||
|
||||
Args:
|
||||
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
|
||||
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
|
||||
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor]):
|
||||
- mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
|
||||
- iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
|
||||
|
||||
Examples:
|
||||
>>> decoder = SAM2MaskDecoder(...)
|
||||
>>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
|
||||
>>> all_iou_scores = torch.rand(2, 4)
|
||||
>>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
|
||||
>>> print(mask_logits.shape, iou_scores.shape)
|
||||
torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
|
||||
"""
|
||||
# The best mask from multimask output tokens (1~3)
|
||||
multimask_logits = all_mask_logits[:, 1:, :, :]
|
||||
multimask_iou_scores = all_iou_scores[:, 1:]
|
||||
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
||||
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
|
||||
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
||||
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
||||
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
||||
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
||||
|
||||
# The mask from singlemask output token 0 and its stability score
|
||||
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
||||
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
||||
stability_scores = self._get_stability_scores(singlemask_logits)
|
||||
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
||||
|
||||
# Dynamically fall back to best multimask output upon low stability scores.
|
||||
mask_logits_out = torch.where(
|
||||
is_stable[..., None, None].expand_as(singlemask_logits),
|
||||
singlemask_logits,
|
||||
best_multimask_logits,
|
||||
)
|
||||
iou_scores_out = torch.where(
|
||||
is_stable.expand_as(singlemask_iou_scores),
|
||||
singlemask_iou_scores,
|
||||
best_multimask_iou_scores,
|
||||
)
|
||||
return mask_logits_out, iou_scores_out
|
||||
|
|
|
|||
|
|
@ -1,30 +1,48 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import Any, Optional, Tuple, Type
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d, MLPBlock
|
||||
from ultralytics.nn.modules import LayerNorm2d
|
||||
|
||||
from .blocks import (
|
||||
Block,
|
||||
CXBlock,
|
||||
Fuser,
|
||||
MaskDownSampler,
|
||||
MultiScaleBlock,
|
||||
PatchEmbed,
|
||||
PositionEmbeddingRandom,
|
||||
PositionEmbeddingSine,
|
||||
)
|
||||
|
||||
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding an image into a compact latent space. The
|
||||
encoder takes an image, splits it into patches, and processes these patches through a series of transformer blocks.
|
||||
The encoded patches are then processed through a neck to generate the final encoded representation.
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
|
||||
|
||||
This class and its supporting functions below lightly adapted from the ViTDet backbone available at
|
||||
https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py.
|
||||
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
|
||||
encoded representation through a neck module.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images, assumed to be square.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter, optional): Absolute positional embedding for patches.
|
||||
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
|
||||
neck (nn.Sequential): Neck module to further process the output.
|
||||
|
||||
Methods:
|
||||
forward: Processes input through patch embedding, positional embedding, blocks, and neck.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
|
||||
>>> input_image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(input_image)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -47,22 +65,38 @@ class ImageEncoderViT(nn.Module):
|
|||
global_attn_indexes: Tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
patch_size (int): Patch size.
|
||||
img_size (int): Input image size, assumed to be square.
|
||||
patch_size (int): Size of image patches.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
depth (int): Depth of ViT.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_abs_pos (bool): If True, use absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks.
|
||||
global_attn_indexes (list): Indexes for blocks using global attention.
|
||||
embed_dim (int): Dimension of patch embeddings.
|
||||
depth (int): Number of transformer blocks.
|
||||
num_heads (int): Number of attention heads in each block.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
out_chans (int): Number of output channels from the neck module.
|
||||
qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to use.
|
||||
act_layer (Type[nn.Module]): Type of activation layer to use.
|
||||
use_abs_pos (bool): If True, uses absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
|
||||
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
|
||||
window_size (int): Size of attention window for windowed attention blocks.
|
||||
global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks.
|
||||
neck (nn.Sequential): Neck module for final processing.
|
||||
|
||||
Examples:
|
||||
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
|
||||
>>> input_image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(input_image)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
|
@ -114,9 +148,7 @@ class ImageEncoderViT(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes input through patch embedding, applies positional embedding if present, and passes through blocks
|
||||
and neck.
|
||||
"""
|
||||
"""Processes input through patch embedding, positional embedding, transformer blocks, and neck module."""
|
||||
x = self.patch_embed(x)
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.pos_embed
|
||||
|
|
@ -127,8 +159,7 @@ class ImageEncoderViT(nn.Module):
|
|||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encodes different types of prompts, including points, boxes, and masks, for input to SAM's mask decoder. The encoder
|
||||
produces both sparse and dense embeddings for the input prompts.
|
||||
Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
|
|
@ -137,10 +168,23 @@ class PromptEncoder(nn.Module):
|
|||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not a part of any label.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
|
||||
mask_input_size (Tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
|
||||
|
||||
Methods:
|
||||
get_dense_pe: Returns the positional encoding used to encode point prompts.
|
||||
forward: Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
|
||||
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -152,18 +196,37 @@ class PromptEncoder(nn.Module):
|
|||
activation: Type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Encodes prompts for input to SAM's mask decoder.
|
||||
Initializes the PromptEncoder module for encoding various types of prompts.
|
||||
|
||||
This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder,
|
||||
producing both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The prompts' embedding dimension
|
||||
image_embedding_size (tuple(int, int)): The spatial size of the
|
||||
image embedding, as (H, W).
|
||||
input_image_size (int): The padded size of the image as input
|
||||
to the image encoder, as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for
|
||||
encoding input masks.
|
||||
activation (nn.Module): The activation to use when encoding
|
||||
input masks.
|
||||
embed_dim (int): The dimension of the embeddings.
|
||||
image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W).
|
||||
input_image_size (Tuple[int, int]): The padded size of the input image as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for encoding input masks.
|
||||
activation (Type[nn.Module]): The activation function to use when encoding input masks.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
input_image_size (Tuple[int, int]): Size of the input image as (H, W).
|
||||
image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W).
|
||||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
|
||||
mask_input_size (Tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
|
||||
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
|
|
@ -190,16 +253,25 @@ class PromptEncoder(nn.Module):
|
|||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the positional encoding used to encode point prompts, applied to a dense set of points the shape of the
|
||||
image encoding.
|
||||
Returns the dense positional encoding used for encoding point prompts.
|
||||
|
||||
This method generates a positional encoding for a dense set of points matching the shape of the image
|
||||
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Positional encoding with shape 1x(embed_dim)x(embedding_h)x(embedding_w)
|
||||
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
|
||||
height and width of the image embedding size, respectively.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> dense_pe = prompt_encoder.get_dense_pe()
|
||||
>>> print(dense_pe.shape)
|
||||
torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
||||
"""Embeds point prompts."""
|
||||
"""Embeds point prompts by applying positional encoding and label-specific embeddings."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
||||
|
|
@ -216,7 +288,7 @@ class PromptEncoder(nn.Module):
|
|||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds box prompts."""
|
||||
"""Embeds box prompts by applying positional encoding and adding corner embeddings."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
||||
|
|
@ -225,7 +297,7 @@ class PromptEncoder(nn.Module):
|
|||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embeds mask inputs."""
|
||||
"""Embeds mask inputs by downscaling and processing through convolutional layers."""
|
||||
return self.mask_downscaling(masks)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -258,14 +330,25 @@ class PromptEncoder(nn.Module):
|
|||
Embeds different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (tuple(torch.Tensor, torch.Tensor), None): point coordinates and labels to embed.
|
||||
boxes (torch.Tensor, None): boxes to embed
|
||||
masks (torch.Tensor, None): masks to embed
|
||||
points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
|
||||
tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
|
||||
shape (B, N).
|
||||
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
|
||||
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: sparse embeddings for the points and boxes, with shape BxNx(embed_dim), where N is determined
|
||||
by the number of input points and boxes.
|
||||
torch.Tensor: dense embeddings for the masks, in the shape Bx(embed_dim)x(embed_H)x(embed_W)
|
||||
(Tuple[torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
|
||||
- dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
|
||||
|
||||
Examples:
|
||||
>>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_emb, dense_emb = encoder(points, boxes, masks)
|
||||
>>> print(sparse_emb.shape, dense_emb.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
|
||||
|
|
@ -287,319 +370,421 @@ class PromptEncoder(nn.Module):
|
|||
return sparse_embeddings, dense_embeddings
|
||||
|
||||
|
||||
class PositionEmbeddingRandom(nn.Module):
|
||||
"""Positional encoding using random spatial frequencies."""
|
||||
class MemoryEncoder(nn.Module):
|
||||
"""
|
||||
Encodes pixel features and masks into a memory representation for efficient image segmentation.
|
||||
|
||||
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
||||
"""Initializes a position embedding using random spatial frequencies."""
|
||||
super().__init__()
|
||||
if scale is None or scale <= 0.0:
|
||||
scale = 1.0
|
||||
self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats)))
|
||||
This class processes pixel-level features and masks, fusing them to generate encoded memory representations
|
||||
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
|
||||
|
||||
# Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation'
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = False
|
||||
Attributes:
|
||||
mask_downsampler (MaskDownSampler): Module for downsampling input masks.
|
||||
pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
|
||||
fuser (Fuser): Module for fusing pixel features and masks.
|
||||
position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
|
||||
out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
|
||||
|
||||
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
"""Positionally encode points that are normalized to [0,1]."""
|
||||
# Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
||||
coords = 2 * coords - 1
|
||||
coords = coords @ self.positional_encoding_gaussian_matrix
|
||||
coords = 2 * np.pi * coords
|
||||
# Outputs d_1 x ... x d_n x C shape
|
||||
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
||||
Methods:
|
||||
forward: Processes input pixel features and masks to generate encoded memory representations.
|
||||
|
||||
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional encoding for a grid of the specified size."""
|
||||
h, w = size
|
||||
device: Any = self.positional_encoding_gaussian_matrix.device
|
||||
grid = torch.ones((h, w), device=device, dtype=torch.float32)
|
||||
y_embed = grid.cumsum(dim=0) - 0.5
|
||||
x_embed = grid.cumsum(dim=1) - 0.5
|
||||
y_embed = y_embed / h
|
||||
x_embed = x_embed / w
|
||||
|
||||
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
||||
return pe.permute(2, 0, 1) # C x H x W
|
||||
|
||||
def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Positionally encode points that are not normalized to [0,1]."""
|
||||
coords = coords_input.clone()
|
||||
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
||||
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
||||
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
"""Transformer blocks with support of window attention and residual propagation blocks."""
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
|
||||
>>> pix_feat = torch.randn(1, 256, 64, 64)
|
||||
>>> masks = torch.randn(1, 1, 64, 64)
|
||||
>>> encoded_feat, pos = encoder(pix_feat, masks)
|
||||
>>> print(encoded_feat.shape, pos.shape)
|
||||
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: Type[nn.Module] = nn.GELU,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
out_dim,
|
||||
in_dim=256, # in_dim of pix_feats
|
||||
):
|
||||
"""Initializes the MemoryEncoder for encoding pixel features and masks into memory representations."""
|
||||
super().__init__()
|
||||
|
||||
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
||||
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
|
||||
self.out_proj = nn.Identity()
|
||||
if out_dim != in_dim:
|
||||
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pix_feat: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
skip_mask_sigmoid: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Processes pixel features and masks to generate encoded memory representations for segmentation."""
|
||||
if not skip_mask_sigmoid:
|
||||
masks = F.sigmoid(masks)
|
||||
masks = self.mask_downsampler(masks)
|
||||
|
||||
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
|
||||
pix_feat = pix_feat.to(masks.device)
|
||||
|
||||
x = self.pix_feat_proj(pix_feat)
|
||||
x = x + masks
|
||||
x = self.fuser(x)
|
||||
x = self.out_proj(x)
|
||||
|
||||
pos = self.position_encoding(x).to(x.dtype)
|
||||
|
||||
return {"vision_features": x, "vision_pos_enc": [pos]}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
"""
|
||||
Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.
|
||||
|
||||
This class combines a trunk network for feature extraction with a neck network for feature refinement
|
||||
and positional encoding generation. It can optionally discard the lowest resolution features.
|
||||
|
||||
Attributes:
|
||||
trunk (nn.Module): The trunk network for initial feature extraction.
|
||||
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
|
||||
scalp (int): Number of lowest resolution feature levels to discard.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input image through the trunk and neck networks.
|
||||
|
||||
Examples:
|
||||
>>> trunk = SomeTrunkNetwork()
|
||||
>>> neck = SomeNeckNetwork()
|
||||
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
|
||||
>>> image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(image)
|
||||
>>> print(output.keys())
|
||||
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trunk: nn.Module,
|
||||
neck: nn.Module,
|
||||
scalp: int = 0,
|
||||
):
|
||||
"""Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement."""
|
||||
super().__init__()
|
||||
self.trunk = trunk
|
||||
self.neck = neck
|
||||
self.scalp = scalp
|
||||
assert (
|
||||
self.trunk.channel_list == self.neck.backbone_channel_list
|
||||
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
||||
|
||||
def forward(self, sample: torch.Tensor):
|
||||
"""Encodes input through patch embedding, positional embedding, transformer blocks, and neck module."""
|
||||
features, pos = self.neck(self.trunk(sample))
|
||||
if self.scalp > 0:
|
||||
# Discard the lowest resolution features
|
||||
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
||||
|
||||
src = features[-1]
|
||||
output = {
|
||||
"vision_features": src,
|
||||
"vision_pos_enc": pos,
|
||||
"backbone_fpn": features,
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class FpnNeck(nn.Module):
|
||||
"""
|
||||
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
||||
similar to ViT positional embedding interpolation.
|
||||
|
||||
Attributes:
|
||||
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
|
||||
convs (nn.ModuleList): List of convolutional layers for each backbone level.
|
||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
||||
fpn_top_down_levels (List[int]): Levels to have top-down features in outputs.
|
||||
|
||||
Methods:
|
||||
forward: Performs forward pass through the FPN neck.
|
||||
|
||||
Examples:
|
||||
>>> backbone_channels = [64, 128, 256, 512]
|
||||
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
|
||||
>>> outputs, positions = fpn_neck(inputs)
|
||||
>>> print(len(outputs), len(positions))
|
||||
4 4
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
backbone_channel_list: List[int],
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
fpn_interp_model: str = "bilinear",
|
||||
fuse_type: str = "sum",
|
||||
fpn_top_down_levels: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a modified Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
||||
similar to ViT positional embedding interpolation.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads in each ViT block.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
norm_layer (nn.Module): Normalization layer.
|
||||
act_layer (nn.Module): Activation layer.
|
||||
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
window_size (int): Window size for window attention blocks. If it equals 0, then
|
||||
use global attention.
|
||||
input_size (tuple(int, int), None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
d_model (int): Dimension of the model.
|
||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
||||
kernel_size (int): Kernel size for the convolutional layers.
|
||||
stride (int): Stride for the convolutional layers.
|
||||
padding (int): Padding for the convolutional layers.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
||||
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
|
||||
|
||||
Examples:
|
||||
>>> backbone_channels = [64, 128, 256, 512]
|
||||
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
||||
>>> print(fpn_neck)
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
input_size=input_size if window_size == 0 else (window_size, window_size),
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
|
||||
self.convs = nn.ModuleList()
|
||||
self.backbone_channel_list = backbone_channel_list
|
||||
for dim in backbone_channel_list:
|
||||
current = nn.Sequential()
|
||||
current.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=dim,
|
||||
out_channels=d_model,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
),
|
||||
)
|
||||
|
||||
self.convs.append(current)
|
||||
self.fpn_interp_model = fpn_interp_model
|
||||
assert fuse_type in ["sum", "avg"]
|
||||
self.fuse_type = fuse_type
|
||||
|
||||
# levels to have top-down features in its outputs
|
||||
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
||||
# have top-down propagation, while outputs of level 0 and level 1 have only
|
||||
# lateral features from the same backbone level.
|
||||
if fpn_top_down_levels is None:
|
||||
# default is to have top-down features on all levels
|
||||
fpn_top_down_levels = range(len(self.convs))
|
||||
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
||||
|
||||
def forward(self, xs: List[torch.Tensor]):
|
||||
"""
|
||||
Performs forward pass through the Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
|
||||
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
|
||||
|
||||
Args:
|
||||
xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing:
|
||||
- out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape
|
||||
(B, d_model, H, W).
|
||||
- pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map.
|
||||
|
||||
Examples:
|
||||
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
|
||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
|
||||
>>> outputs, positions = fpn_neck(inputs)
|
||||
>>> print(len(outputs), len(positions))
|
||||
4 4
|
||||
"""
|
||||
out = [None] * len(self.convs)
|
||||
pos = [None] * len(self.convs)
|
||||
assert len(xs) == len(self.convs)
|
||||
# fpn forward pass
|
||||
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
||||
prev_features = None
|
||||
# forward in top-down order (from low to high resolution)
|
||||
n = len(self.convs) - 1
|
||||
for i in range(n, -1, -1):
|
||||
x = xs[i]
|
||||
lateral_features = self.convs[n - i](x)
|
||||
if i in self.fpn_top_down_levels and prev_features is not None:
|
||||
top_down_features = F.interpolate(
|
||||
prev_features.to(dtype=torch.float32),
|
||||
scale_factor=2.0,
|
||||
mode=self.fpn_interp_model,
|
||||
align_corners=(None if self.fpn_interp_model == "nearest" else False),
|
||||
antialias=False,
|
||||
)
|
||||
prev_features = lateral_features + top_down_features
|
||||
if self.fuse_type == "avg":
|
||||
prev_features /= 2
|
||||
else:
|
||||
prev_features = lateral_features
|
||||
x_out = prev_features
|
||||
out[i] = x_out
|
||||
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
class Hiera(nn.Module):
|
||||
"""
|
||||
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
|
||||
|
||||
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
|
||||
efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
|
||||
with optional pooling and global attention mechanisms.
|
||||
|
||||
Attributes:
|
||||
window_spec (Tuple[int, ...]): Window sizes for each stage.
|
||||
q_stride (Tuple[int, int]): Downsampling stride between stages.
|
||||
stage_ends (List[int]): Indices of the last block in each stage.
|
||||
q_pool_blocks (List[int]): Indices of blocks where pooling is applied.
|
||||
return_interm_layers (bool): Whether to return intermediate layer outputs.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention.
|
||||
window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background.
|
||||
pos_embed (nn.Parameter): Positional embedding for the background.
|
||||
pos_embed_window (nn.Parameter): Positional embedding for the window.
|
||||
blocks (nn.ModuleList): List of MultiScaleBlock modules.
|
||||
channel_list (List[int]): List of output channel dimensions for each stage.
|
||||
|
||||
Methods:
|
||||
_get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings.
|
||||
forward: Performs the forward pass through the Hiera model.
|
||||
|
||||
Examples:
|
||||
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output_features = model(input_tensor)
|
||||
>>> for feat in output_features:
|
||||
... print(feat.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 96, # initial embed dim
|
||||
num_heads: int = 1, # initial number of heads
|
||||
drop_path_rate: float = 0.0, # stochastic depth
|
||||
q_pool: int = 3, # number of q_pool stages
|
||||
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
||||
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
||||
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
||||
head_mul: float = 2.0, # head_mul factor at stage shift
|
||||
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
||||
# window size per stage, when not using global att.
|
||||
window_spec: Tuple[int, ...] = (
|
||||
8,
|
||||
4,
|
||||
14,
|
||||
7,
|
||||
),
|
||||
# global attn in these blocks
|
||||
global_att_blocks: Tuple[int, ...] = (
|
||||
12,
|
||||
16,
|
||||
20,
|
||||
),
|
||||
return_interm_layers=True, # return feats from every stage
|
||||
):
|
||||
"""Initializes the Hiera model, configuring its hierarchical vision transformer architecture."""
|
||||
super().__init__()
|
||||
|
||||
assert len(stages) == len(window_spec)
|
||||
self.window_spec = window_spec
|
||||
|
||||
depth = sum(stages)
|
||||
self.q_stride = q_stride
|
||||
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
||||
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
||||
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
||||
self.return_interm_layers = return_interm_layers
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
embed_dim=embed_dim,
|
||||
kernel_size=(7, 7),
|
||||
stride=(4, 4),
|
||||
padding=(3, 3),
|
||||
)
|
||||
# Which blocks have global att?
|
||||
self.global_att_blocks = global_att_blocks
|
||||
|
||||
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
||||
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
|
||||
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
cur_stage = 1
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
dim_out = embed_dim
|
||||
# lags by a block, so first block of
|
||||
# next stage uses an initial window size
|
||||
# of previous stage and final window size of current stage
|
||||
window_size = self.window_spec[cur_stage - 1]
|
||||
|
||||
if self.global_att_blocks is not None:
|
||||
window_size = 0 if i in self.global_att_blocks else window_size
|
||||
|
||||
if i - 1 in self.stage_ends:
|
||||
dim_out = int(embed_dim * dim_mul)
|
||||
num_heads = int(num_heads * head_mul)
|
||||
cur_stage += 1
|
||||
|
||||
block = MultiScaleBlock(
|
||||
dim=embed_dim,
|
||||
dim_out=dim_out,
|
||||
num_heads=num_heads,
|
||||
drop_path=dpr[i],
|
||||
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
embed_dim = dim_out
|
||||
self.blocks.append(block)
|
||||
|
||||
self.channel_list = (
|
||||
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
||||
if return_interm_layers
|
||||
else [self.blocks[-1].dim_out]
|
||||
)
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
|
||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generates positional embeddings by interpolating and combining window and background embeddings."""
|
||||
h, w = hw
|
||||
window_embed = self.pos_embed_window
|
||||
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
||||
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
||||
return pos_embed
|
||||
|
||||
self.window_size = window_size
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Performs forward pass through Hiera model, extracting multiscale features from input images."""
|
||||
x = self.patch_embed(x)
|
||||
# x: (B, H, W, C)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Executes a forward pass through the transformer block with window attention and non-overlapping windows."""
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
# Window partition
|
||||
if self.window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, self.window_size)
|
||||
# Add pos embed
|
||||
x = x + self._get_pos_embed(x.shape[1:3])
|
||||
|
||||
x = self.attn(x)
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
||||
outputs = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
|
||||
feats = x.permute(0, 3, 1, 2)
|
||||
outputs.append(feats)
|
||||
|
||||
x = shortcut + x
|
||||
return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head Attention block with relative position embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
input_size: Optional[Tuple[int, int]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Attention module.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
||||
input_size (tuple(int, int), None): Input resolution for calculating the relative
|
||||
positional parameter size.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
|
||||
self.use_rel_pos = use_rel_pos
|
||||
if self.use_rel_pos:
|
||||
assert input_size is not None, "Input size must be provided if using relative positional encoding."
|
||||
# Initialize relative positional embeddings
|
||||
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
||||
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies the forward operation including attention, normalization, MLP, and indexing within window limits."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (3, B, nHead, H * W, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v with shape (B * nHead, H * W, C)
|
||||
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
||||
|
||||
attn = (q * self.scale) @ k.transpose(-2, -1)
|
||||
|
||||
if self.use_rel_pos:
|
||||
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
"""
|
||||
Partition into non-overlapping windows with padding if needed.
|
||||
Args:
|
||||
x (tensor): input tokens with [B, H, W, C].
|
||||
window_size (int): window size.
|
||||
|
||||
Returns:
|
||||
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
||||
(Hp, Wp): padded height and width before partition
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(
|
||||
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Window unpartition into original sequences and removing padding.
|
||||
|
||||
Args:
|
||||
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
||||
window_size (int): window size.
|
||||
pad_hw (Tuple): padded height and width (Hp, Wp).
|
||||
hw (Tuple): original height and width (H, W) before padding.
|
||||
|
||||
Returns:
|
||||
x: unpartitioned sequences with [B, H, W, C].
|
||||
"""
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get relative positional embeddings according to the relative positions of query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): size of query q.
|
||||
k_size (int): size of key k.
|
||||
rel_pos (Tensor): relative position embeddings (L, C).
|
||||
|
||||
Returns:
|
||||
Extracted positional embeddings according to relative positions.
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate decomposed Relative Positional Embeddings from mvitv2 paper at
|
||||
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py.
|
||||
|
||||
Args:
|
||||
attn (Tensor): attention map.
|
||||
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
||||
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
||||
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
||||
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
attn (Tensor): attention map with added relative positional embeddings.
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w
|
||||
)
|
||||
|
||||
return attn
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Tuple[int, int] = (16, 16),
|
||||
stride: Tuple[int, int] = (16, 16),
|
||||
padding: Tuple[int, int] = (0, 0),
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize PatchEmbed module.
|
||||
|
||||
Args:
|
||||
kernel_size (Tuple): kernel size of the projection layer.
|
||||
stride (Tuple): stride of the projection layer.
|
||||
padding (Tuple): padding size of the projection layer.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Patch embedding dimension.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes patch embedding by applying convolution and transposing resulting tensor."""
|
||||
return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -6,11 +6,50 @@ from typing import Optional
|
|||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .sam2_blocks import RoPEAttention
|
||||
from .blocks import RoPEAttention
|
||||
|
||||
|
||||
class MemoryAttentionLayer(nn.Module):
|
||||
"""Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks."""
|
||||
"""
|
||||
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
|
||||
|
||||
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
|
||||
generate memory-based attention outputs.
|
||||
|
||||
Attributes:
|
||||
d_model (int): Dimensionality of the model.
|
||||
dim_feedforward (int): Dimensionality of the feedforward network.
|
||||
dropout_value (float): Dropout rate for regularization.
|
||||
self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
|
||||
cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
|
||||
linear1 (nn.Linear): First linear layer of the feedforward network.
|
||||
linear2 (nn.Linear): Second linear layer of the feedforward network.
|
||||
norm1 (nn.LayerNorm): Layer normalization for self-attention output.
|
||||
norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
|
||||
norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
|
||||
dropout1 (nn.Dropout): Dropout layer after self-attention.
|
||||
dropout2 (nn.Dropout): Dropout layer after cross-attention.
|
||||
dropout3 (nn.Dropout): Dropout layer after feedforward network.
|
||||
activation (nn.ReLU): Activation function for the feedforward network.
|
||||
pos_enc_at_attn (bool): Flag to add positional encoding at attention.
|
||||
pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
|
||||
pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
|
||||
|
||||
Methods:
|
||||
forward: Performs the full memory attention operation on input tensors.
|
||||
_forward_sa: Performs self-attention on input tensor.
|
||||
_forward_ca: Performs cross-attention between target and memory tensors.
|
||||
|
||||
Examples:
|
||||
>>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
|
||||
>>> tgt = torch.randn(1, 100, 256)
|
||||
>>> memory = torch.randn(1, 100, 64)
|
||||
>>> pos = torch.randn(1, 100, 256)
|
||||
>>> query_pos = torch.randn(1, 100, 256)
|
||||
>>> output = layer(tgt, memory, pos, query_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -21,7 +60,7 @@ class MemoryAttentionLayer(nn.Module):
|
|||
pos_enc_at_cross_attn_keys: bool = True,
|
||||
pos_enc_at_cross_attn_queries: bool = False,
|
||||
):
|
||||
"""Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components."""
|
||||
"""Initializes a memory attention layer with self-attention, cross-attention, and feedforward components."""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.dim_feedforward = dim_feedforward
|
||||
|
|
@ -88,7 +127,7 @@ class MemoryAttentionLayer(nn.Module):
|
|||
query_pos: Optional[Tensor] = None,
|
||||
num_k_exclude_rope: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention."""
|
||||
"""Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention."""
|
||||
tgt = self._forward_sa(tgt, query_pos)
|
||||
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
||||
# MLP
|
||||
|
|
@ -99,7 +138,35 @@ class MemoryAttentionLayer(nn.Module):
|
|||
|
||||
|
||||
class MemoryAttention(nn.Module):
|
||||
"""Memory attention module for processing sequential data with self and cross-attention mechanisms."""
|
||||
"""
|
||||
Memory attention module for processing sequential data with self and cross-attention mechanisms.
|
||||
|
||||
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
|
||||
for processing sequential data, particularly useful in transformer-like architectures.
|
||||
|
||||
Attributes:
|
||||
d_model (int): The dimension of the model's hidden state.
|
||||
layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
|
||||
num_layers (int): The number of attention layers.
|
||||
norm (nn.LayerNorm): Layer normalization applied to the output.
|
||||
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
|
||||
batch_first (bool): Whether the input tensors are in batch-first format.
|
||||
|
||||
Methods:
|
||||
forward: Processes input tensors through the attention layers.
|
||||
|
||||
Examples:
|
||||
>>> d_model = 256
|
||||
>>> layer = MemoryAttentionLayer(d_model)
|
||||
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
|
||||
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
|
||||
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
|
||||
>>> curr_pos = torch.randn(10, 32, d_model)
|
||||
>>> memory_pos = torch.randn(20, 32, d_model)
|
||||
>>> output = attention(curr, memory, curr_pos, memory_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([10, 32, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -126,7 +193,7 @@ class MemoryAttention(nn.Module):
|
|||
memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
|
||||
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
||||
):
|
||||
"""Applies self-attention and cross-attention to input tensors, processing through multiple layers."""
|
||||
"""Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms."""
|
||||
if isinstance(curr, list):
|
||||
assert isinstance(curr_pos, list)
|
||||
assert len(curr) == len(curr_pos) == 1
|
||||
|
|
@ -9,25 +9,48 @@
|
|||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from .decoders import MaskDecoder
|
||||
from ultralytics.nn.modules import MLP
|
||||
|
||||
from .blocks import SAM2TwoWayTransformer
|
||||
from .decoders import MaskDecoder, SAM2MaskDecoder
|
||||
from .encoders import ImageEncoderViT, PromptEncoder
|
||||
from .utils import get_1d_sine_pe, select_closest_cond_frames
|
||||
|
||||
# a large negative value as a placeholder score for missing objects
|
||||
NO_OBJ_SCORE = -1024.0
|
||||
|
||||
|
||||
class SAMModel(nn.Module):
|
||||
"""
|
||||
SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
|
||||
image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
|
||||
the mask decoder to predict object masks.
|
||||
Segment Anything Model (SAM) for object segmentation tasks.
|
||||
|
||||
This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images
|
||||
and input prompts.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): Threshold value for mask prediction.
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
|
||||
pixel_mean (List[float]): Mean pixel values for image normalization.
|
||||
pixel_std (List[float]): Standard deviation values for image normalization.
|
||||
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
|
||||
pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
|
||||
pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
|
||||
|
||||
Methods:
|
||||
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
|
||||
|
||||
Examples:
|
||||
>>> image_encoder = ImageEncoderViT(...)
|
||||
>>> prompt_encoder = PromptEncoder(...)
|
||||
>>> mask_decoder = MaskDecoder(...)
|
||||
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
|
||||
>>> # Further usage depends on SAMPredictor class
|
||||
|
||||
Notes:
|
||||
All forward() operations are implemented in the SAMPredictor class.
|
||||
"""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
|
|
@ -43,17 +66,22 @@ class SAMModel(nn.Module):
|
|||
"""
|
||||
Initialize the SAMModel class to predict object masks from an image and input prompts.
|
||||
|
||||
Note:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
|
||||
Args:
|
||||
image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings.
|
||||
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts.
|
||||
pixel_mean (List[float], optional): Mean values for normalizing pixels in the input image. Defaults to
|
||||
(123.675, 116.28, 103.53).
|
||||
pixel_std (List[float], optional): Std values for normalizing pixels in the input image. Defaults to
|
||||
(58.395, 57.12, 57.375).
|
||||
pixel_mean (List[float]): Mean values for normalizing pixels in the input image.
|
||||
pixel_std (List[float]): Std values for normalizing pixels in the input image.
|
||||
|
||||
Examples:
|
||||
>>> image_encoder = ImageEncoderViT(...)
|
||||
>>> prompt_encoder = PromptEncoder(...)
|
||||
>>> mask_decoder = MaskDecoder(...)
|
||||
>>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder)
|
||||
>>> # Further usage depends on SAMPredictor class
|
||||
|
||||
Notes:
|
||||
All forward() operations moved to SAMPredictor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.image_encoder = image_encoder
|
||||
|
|
@ -61,3 +89,846 @@ class SAMModel(nn.Module):
|
|||
self.mask_decoder = mask_decoder
|
||||
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
|
||||
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
||||
|
||||
|
||||
class SAM2Model(torch.nn.Module):
|
||||
"""
|
||||
SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.
|
||||
|
||||
This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms
|
||||
for temporal consistency and efficient tracking of objects across frames.
|
||||
|
||||
Attributes:
|
||||
mask_threshold (float): Threshold value for mask prediction.
|
||||
image_encoder (ImageEncoderViT): Visual encoder for extracting image features.
|
||||
memory_attention (nn.Module): Module for attending to memory features.
|
||||
memory_encoder (nn.Module): Encoder for generating memory representations.
|
||||
num_maskmem (int): Number of accessible memory frames.
|
||||
image_size (int): Size of input images.
|
||||
backbone_stride (int): Stride of the backbone network output.
|
||||
sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings.
|
||||
sam_image_embedding_size (int): Size of SAM image embeddings.
|
||||
sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts.
|
||||
sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks.
|
||||
obj_ptr_proj (nn.Module): Projection layer for object pointers.
|
||||
obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers.
|
||||
|
||||
Methods:
|
||||
forward_image: Processes image batch through encoder to extract multi-level features.
|
||||
track_step: Performs a single tracking step, updating object masks and memory features.
|
||||
|
||||
Examples:
|
||||
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
|
||||
>>> image_batch = torch.rand(1, 3, 512, 512)
|
||||
>>> features = model.forward_image(image_batch)
|
||||
>>> track_results = model.track_step(0, True, features, None, None, None, {})
|
||||
"""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder,
|
||||
memory_attention,
|
||||
memory_encoder,
|
||||
num_maskmem=7,
|
||||
image_size=512,
|
||||
backbone_stride=16,
|
||||
sigmoid_scale_for_mem_enc=1.0,
|
||||
sigmoid_bias_for_mem_enc=0.0,
|
||||
binarize_mask_from_pts_for_mem_enc=False,
|
||||
use_mask_input_as_output_without_sam=False,
|
||||
max_cond_frames_in_attn=-1,
|
||||
directly_add_no_mem_embed=False,
|
||||
use_high_res_features_in_sam=False,
|
||||
multimask_output_in_sam=False,
|
||||
multimask_min_pt_num=1,
|
||||
multimask_max_pt_num=1,
|
||||
multimask_output_for_tracking=False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
iou_prediction_use_sigmoid=False,
|
||||
memory_temporal_stride_for_eval=1,
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
non_overlap_masks_for_mem_enc=False,
|
||||
use_obj_ptrs_in_encoder=False,
|
||||
max_obj_ptrs_in_encoder=16,
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
proj_tpos_enc_in_obj_ptrs=False,
|
||||
only_obj_ptrs_in_the_past_for_eval=False,
|
||||
pred_obj_scores: bool = False,
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
fixed_no_obj_ptr: bool = False,
|
||||
soft_no_obj_ptr: bool = False,
|
||||
use_mlp_for_obj_ptr_proj: bool = False,
|
||||
sam_mask_decoder_extra_args=None,
|
||||
compile_image_encoder: bool = False,
|
||||
):
|
||||
"""
|
||||
Initializes the SAM2Model for video object segmentation with memory-based tracking.
|
||||
|
||||
Args:
|
||||
image_encoder (nn.Module): Visual encoder for extracting image features.
|
||||
memory_attention (nn.Module): Module for attending to memory features.
|
||||
memory_encoder (nn.Module): Encoder for generating memory representations.
|
||||
num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames).
|
||||
image_size (int): Size of input images.
|
||||
backbone_stride (int): Stride of the image backbone output.
|
||||
sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability.
|
||||
sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability.
|
||||
binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames
|
||||
with clicks during evaluation.
|
||||
use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM
|
||||
prompt encoder and mask decoder on frames with mask input.
|
||||
max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention.
|
||||
-1 means no limit.
|
||||
directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the
|
||||
first frame.
|
||||
use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder.
|
||||
multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial
|
||||
conditioning frames.
|
||||
multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM.
|
||||
multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM.
|
||||
multimask_output_for_tracking (bool): Whether to use multimask output for tracking.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
|
||||
memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
|
||||
add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
|
||||
frame list.
|
||||
non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
|
||||
memory encoder during evaluation.
|
||||
use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
|
||||
max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder
|
||||
cross-attention.
|
||||
add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in
|
||||
the encoder.
|
||||
proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
|
||||
encoding in object pointers.
|
||||
only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
|
||||
during evaluation.
|
||||
pred_obj_scores (bool): Whether to predict if there is an object in the frame.
|
||||
pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores.
|
||||
fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
|
||||
soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
|
||||
use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
|
||||
sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
|
||||
compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
|
||||
|
||||
Examples:
|
||||
>>> image_encoder = ImageEncoderViT(...)
|
||||
>>> memory_attention = SAM2TwoWayTransformer(...)
|
||||
>>> memory_encoder = nn.Sequential(...)
|
||||
>>> model = SAM2Model(image_encoder, memory_attention, memory_encoder)
|
||||
>>> image_batch = torch.rand(1, 3, 512, 512)
|
||||
>>> features = model.forward_image(image_batch)
|
||||
>>> track_results = model.track_step(0, True, features, None, None, None, {})
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# Part 1: the image backbone
|
||||
self.image_encoder = image_encoder
|
||||
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
|
||||
self.use_high_res_features_in_sam = use_high_res_features_in_sam
|
||||
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
|
||||
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
|
||||
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
|
||||
if use_obj_ptrs_in_encoder:
|
||||
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
|
||||
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
|
||||
# so that it can be fed into the SAM mask decoder to generate a pointer.
|
||||
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
|
||||
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
|
||||
if proj_tpos_enc_in_obj_ptrs:
|
||||
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
|
||||
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
|
||||
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
|
||||
|
||||
# Part 2: memory attention to condition current frame's visual features
|
||||
# with memories (and obj ptrs) from past frames
|
||||
self.memory_attention = memory_attention
|
||||
self.hidden_dim = memory_attention.d_model
|
||||
|
||||
# Part 3: memory encoder for the previous frame's outputs
|
||||
self.memory_encoder = memory_encoder
|
||||
self.mem_dim = self.hidden_dim
|
||||
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
|
||||
# if there is compression of memories along channel dim
|
||||
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
||||
self.num_maskmem = num_maskmem # Number of memories accessible
|
||||
# Temporal encoding of the memories
|
||||
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
|
||||
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
||||
# a single token to indicate no memory embedding from previous frames
|
||||
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
||||
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
||||
trunc_normal_(self.no_mem_embed, std=0.02)
|
||||
trunc_normal_(self.no_mem_pos_enc, std=0.02)
|
||||
self.directly_add_no_mem_embed = directly_add_no_mem_embed
|
||||
# Apply sigmoid to the output raw mask logits (to turn them from
|
||||
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
|
||||
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
|
||||
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
|
||||
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
|
||||
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
|
||||
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
|
||||
# On frames with mask input, whether to directly output the input mask without
|
||||
# using a SAM prompt encoder + mask decoder
|
||||
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
|
||||
self.multimask_output_in_sam = multimask_output_in_sam
|
||||
self.multimask_min_pt_num = multimask_min_pt_num
|
||||
self.multimask_max_pt_num = multimask_max_pt_num
|
||||
self.multimask_output_for_tracking = multimask_output_for_tracking
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
|
||||
|
||||
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
|
||||
# and SAM-style mask decoder for the final mask output
|
||||
self.image_size = image_size
|
||||
self.backbone_stride = backbone_stride
|
||||
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
self.pred_obj_scores_mlp = pred_obj_scores_mlp
|
||||
self.fixed_no_obj_ptr = fixed_no_obj_ptr
|
||||
self.soft_no_obj_ptr = soft_no_obj_ptr
|
||||
if self.fixed_no_obj_ptr:
|
||||
assert self.pred_obj_scores
|
||||
assert self.use_obj_ptrs_in_encoder
|
||||
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
|
||||
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
|
||||
trunc_normal_(self.no_obj_ptr, std=0.02)
|
||||
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
|
||||
|
||||
self._build_sam_heads()
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
||||
|
||||
# Model compilation
|
||||
if compile_image_encoder:
|
||||
# Compile the forward function (not the full module) to allow loading checkpoints.
|
||||
print("Image encoder compilation is enabled. First forward pass will be slow.")
|
||||
self.image_encoder.forward = torch.compile(
|
||||
self.image_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Returns the device on which the model's parameters are stored."""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Processes image and prompt inputs to generate object masks and scores in video sequences."""
|
||||
raise NotImplementedError(
|
||||
"Please use the corresponding methods in SAM2VideoPredictor for inference."
|
||||
"See notebooks/video_predictor_example.ipynb for an example."
|
||||
)
|
||||
|
||||
def _build_sam_heads(self):
|
||||
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
|
||||
self.sam_prompt_embed_dim = self.hidden_dim
|
||||
self.sam_image_embedding_size = self.image_size // self.backbone_stride
|
||||
|
||||
# build PromptEncoder and MaskDecoder from SAM
|
||||
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
|
||||
self.sam_prompt_encoder = PromptEncoder(
|
||||
embed_dim=self.sam_prompt_embed_dim,
|
||||
image_embedding_size=(
|
||||
self.sam_image_embedding_size,
|
||||
self.sam_image_embedding_size,
|
||||
),
|
||||
input_image_size=(self.image_size, self.image_size),
|
||||
mask_in_chans=16,
|
||||
)
|
||||
self.sam_mask_decoder = SAM2MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=SAM2TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=self.sam_prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=self.sam_prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
use_high_res_features=self.use_high_res_features_in_sam,
|
||||
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
|
||||
pred_obj_scores=self.pred_obj_scores,
|
||||
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
|
||||
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
|
||||
**(self.sam_mask_decoder_extra_args or {}),
|
||||
)
|
||||
if self.use_obj_ptrs_in_encoder:
|
||||
# a linear projection on SAM output tokens to turn them into object pointers
|
||||
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
||||
if self.use_mlp_for_obj_ptr_proj:
|
||||
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
|
||||
else:
|
||||
self.obj_ptr_proj = torch.nn.Identity()
|
||||
if self.proj_tpos_enc_in_obj_ptrs:
|
||||
# a linear projection on temporal positional encoding in object pointers to
|
||||
# avoid potential interference with spatial positional encoding
|
||||
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
|
||||
else:
|
||||
self.obj_ptr_tpos_proj = torch.nn.Identity()
|
||||
|
||||
def _forward_sam_heads(
|
||||
self,
|
||||
backbone_features,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
high_res_features=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Forward pass through SAM prompt encoders and mask heads.
|
||||
|
||||
This method processes image features and optional point/mask inputs to generate object masks and scores.
|
||||
|
||||
Args:
|
||||
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
|
||||
point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
|
||||
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
|
||||
pixel-unit coordinates in (x, y) format for P input points.
|
||||
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
|
||||
0 means negative clicks, and -1 means padding.
|
||||
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
|
||||
same spatial size as the image.
|
||||
high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
|
||||
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
|
||||
for SAM decoder.
|
||||
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
|
||||
output only 1 mask and its IoU estimate.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
|
||||
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
|
||||
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
|
||||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
|
||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
|
||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
||||
object_score_logits: Tensor of shape (B,) with object score logits.
|
||||
|
||||
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
|
||||
|
||||
Examples:
|
||||
>>> backbone_features = torch.rand(1, 256, 32, 32)
|
||||
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
|
||||
>>> mask_inputs = torch.rand(1, 1, 512, 512)
|
||||
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
|
||||
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
|
||||
"""
|
||||
B = backbone_features.size(0)
|
||||
device = backbone_features.device
|
||||
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
||||
assert backbone_features.size(2) == self.sam_image_embedding_size
|
||||
assert backbone_features.size(3) == self.sam_image_embedding_size
|
||||
|
||||
# a) Handle point prompts
|
||||
if point_inputs is not None:
|
||||
sam_point_coords = point_inputs["point_coords"]
|
||||
sam_point_labels = point_inputs["point_labels"]
|
||||
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
||||
else:
|
||||
# If no points are provide, pad with an empty point (with label -1)
|
||||
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
||||
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
||||
|
||||
# b) Handle mask prompts
|
||||
if mask_inputs is not None:
|
||||
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
||||
# and feed it as a dense mask prompt into the SAM mask encoder
|
||||
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
||||
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
||||
sam_mask_prompt = F.interpolate(
|
||||
mask_inputs.float(),
|
||||
size=self.sam_prompt_encoder.mask_input_size,
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
antialias=True, # use antialias for downsampling
|
||||
)
|
||||
else:
|
||||
sam_mask_prompt = mask_inputs
|
||||
else:
|
||||
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
||||
# a learned `no_mask_embed` to indicate no mask input in this case).
|
||||
sam_mask_prompt = None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
||||
points=(sam_point_coords, sam_point_labels),
|
||||
boxes=None,
|
||||
masks=sam_mask_prompt,
|
||||
)
|
||||
(
|
||||
low_res_multimasks,
|
||||
ious,
|
||||
sam_output_tokens,
|
||||
object_score_logits,
|
||||
) = self.sam_mask_decoder(
|
||||
image_embeddings=backbone_features,
|
||||
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
repeat_image=False, # the image is already batched
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
is_obj_appearing = object_score_logits > 0
|
||||
|
||||
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
||||
# consistent with the actual mask prediction
|
||||
low_res_multimasks = torch.where(
|
||||
is_obj_appearing[:, None, None],
|
||||
low_res_multimasks,
|
||||
NO_OBJ_SCORE,
|
||||
)
|
||||
|
||||
# convert masks from possibly bfloat16 (or float16) to float32
|
||||
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
||||
low_res_multimasks = low_res_multimasks.float()
|
||||
high_res_multimasks = F.interpolate(
|
||||
low_res_multimasks,
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
sam_output_token = sam_output_tokens[:, 0]
|
||||
if multimask_output:
|
||||
# take the best mask prediction (with the highest IoU estimation)
|
||||
best_iou_inds = torch.argmax(ious, dim=-1)
|
||||
batch_inds = torch.arange(B, device=device)
|
||||
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
if sam_output_tokens.size(1) > 1:
|
||||
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
||||
else:
|
||||
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
||||
|
||||
# Extract object pointer from the SAM output token (with occlusion handling)
|
||||
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
||||
if self.pred_obj_scores:
|
||||
# Allow *soft* no obj ptr, unlike for masks
|
||||
if self.soft_no_obj_ptr:
|
||||
# Only hard possible with gt
|
||||
assert not self.teacher_force_obj_scores_for_mem
|
||||
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
||||
else:
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
|
||||
if self.fixed_no_obj_ptr:
|
||||
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
||||
|
||||
return (
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
)
|
||||
|
||||
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
||||
"""Processes mask inputs directly as output, bypassing SAM encoder/decoder."""
|
||||
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
|
||||
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
|
||||
mask_inputs_float = mask_inputs.float()
|
||||
high_res_masks = mask_inputs_float * out_scale + out_bias
|
||||
low_res_masks = F.interpolate(
|
||||
high_res_masks,
|
||||
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
antialias=True, # use antialias for downsampling
|
||||
)
|
||||
# a dummy IoU prediction of all 1's under mask input
|
||||
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
||||
if not self.use_obj_ptrs_in_encoder:
|
||||
# all zeros as a dummy object pointer (of shape [B, C])
|
||||
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
|
||||
else:
|
||||
# produce an object pointer using the SAM decoder from the mask input
|
||||
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
||||
backbone_features=backbone_features,
|
||||
mask_inputs=self.mask_downsample(mask_inputs_float),
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
|
||||
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
|
||||
# on the object_scores from the SAM decoder.
|
||||
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
|
||||
is_obj_appearing = is_obj_appearing[..., None]
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
|
||||
if self.pred_obj_scores:
|
||||
if self.fixed_no_obj_ptr:
|
||||
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
||||
|
||||
return (
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
)
|
||||
|
||||
def forward_image(self, img_batch: torch.Tensor):
|
||||
"""Processes image batch through encoder to extract multi-level features for SAM model."""
|
||||
backbone_out = self.image_encoder(img_batch)
|
||||
if self.use_high_res_features_in_sam:
|
||||
# precompute projected level 0 and level 1 features in SAM decoder
|
||||
# to avoid running it again on every SAM click
|
||||
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
||||
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
||||
return backbone_out
|
||||
|
||||
def _prepare_backbone_features(self, backbone_out):
|
||||
"""Prepares and flattens visual features from the image backbone output for further processing."""
|
||||
backbone_out = backbone_out.copy()
|
||||
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
||||
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
||||
|
||||
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
|
||||
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
|
||||
|
||||
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
||||
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
|
||||
|
||||
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
|
||||
|
||||
def _prepare_memory_conditioned_features(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
||||
):
|
||||
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
|
||||
B = current_vision_feats[-1].size(1) # batch size on this frame
|
||||
C = self.hidden_dim
|
||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
||||
device = current_vision_feats[-1].device
|
||||
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
|
||||
# In this case, we skip the fusion with any memory.
|
||||
if self.num_maskmem == 0: # Disable memory and skip fusion
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat
|
||||
|
||||
num_obj_ptr_tokens = 0
|
||||
# Step 1: condition the visual features of the current frame on previous memories
|
||||
if not is_init_cond_frame:
|
||||
# Retrieve the memories encoded with the maskmem backbone
|
||||
to_cat_memory, to_cat_memory_pos_embed = [], []
|
||||
# Add conditioning frames's output first (all cond frames have t_pos=0 for
|
||||
# when getting temporal positional embedding below)
|
||||
assert len(output_dict["cond_frame_outputs"]) > 0
|
||||
# Select a maximum number of temporally closest cond frames for cross attention
|
||||
cond_outputs = output_dict["cond_frame_outputs"]
|
||||
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
|
||||
frame_idx, cond_outputs, self.max_cond_frames_in_attn
|
||||
)
|
||||
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
|
||||
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
|
||||
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
|
||||
# We also allow taking the memory frame non-consecutively (with r>1), in which case
|
||||
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
|
||||
r = self.memory_temporal_stride_for_eval
|
||||
for t_pos in range(1, self.num_maskmem):
|
||||
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
||||
if t_rel == 1:
|
||||
# for t_rel == 1, we take the last frame (regardless of r)
|
||||
if not track_in_reverse:
|
||||
# the frame immediately before this frame (i.e. frame_idx - 1)
|
||||
prev_frame_idx = frame_idx - t_rel
|
||||
else:
|
||||
# the frame immediately after this frame (i.e. frame_idx + 1)
|
||||
prev_frame_idx = frame_idx + t_rel
|
||||
else:
|
||||
# for t_rel >= 2, we take the memory frame from every r-th frames
|
||||
if not track_in_reverse:
|
||||
# first find the nearest frame among every r-th frames before this frame
|
||||
# for r=1, this would be (frame_idx - 2)
|
||||
prev_frame_idx = ((frame_idx - 2) // r) * r
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
|
||||
else:
|
||||
# first find the nearest frame among every r-th frames after this frame
|
||||
# for r=1, this would be (frame_idx + 2)
|
||||
prev_frame_idx = -(-(frame_idx + 2) // r) * r
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
|
||||
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
|
||||
if out is None:
|
||||
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
|
||||
# frames, we still attend to it as if it's a non-conditioning frame.
|
||||
out = unselected_cond_outputs.get(prev_frame_idx, None)
|
||||
t_pos_and_prevs.append((t_pos, out))
|
||||
|
||||
for t_pos, prev in t_pos_and_prevs:
|
||||
if prev is None:
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
||||
to_cat_memory_pos_embed.append(maskmem_enc)
|
||||
|
||||
# Construct the list of past object pointers
|
||||
if self.use_obj_ptrs_in_encoder:
|
||||
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
|
||||
# First add those object pointers from selected conditioning frames
|
||||
# (optionally, only include object pointers in the past during evaluation)
|
||||
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
||||
ptr_cond_outputs = {
|
||||
t: out
|
||||
for t, out in selected_cond_outputs.items()
|
||||
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
||||
}
|
||||
else:
|
||||
ptr_cond_outputs = selected_cond_outputs
|
||||
pos_and_ptrs = [
|
||||
# Temporal pos encoding contains how far away each pointer is from current frame
|
||||
(abs(frame_idx - t), out["obj_ptr"])
|
||||
for t, out in ptr_cond_outputs.items()
|
||||
]
|
||||
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
|
||||
for t_diff in range(1, max_obj_ptrs_in_encoder):
|
||||
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
||||
if t < 0 or (num_frames is not None and t >= num_frames):
|
||||
break
|
||||
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
|
||||
if out is not None:
|
||||
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
||||
# If we have at least one object pointer, add them to the across attention
|
||||
if len(pos_and_ptrs) > 0:
|
||||
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
||||
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
|
||||
obj_ptrs = torch.stack(ptrs_list, dim=0)
|
||||
# a temporal positional embedding based on how far each object pointer is from
|
||||
# the current frame (sine embedding normalized by the max pointer num).
|
||||
if self.add_tpos_enc_to_obj_ptrs:
|
||||
t_diff_max = max_obj_ptrs_in_encoder - 1
|
||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
||||
obj_pos = torch.tensor(pos_list, device=device)
|
||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
||||
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
||||
else:
|
||||
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
||||
if self.mem_dim < C:
|
||||
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
||||
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
|
||||
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
||||
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
||||
to_cat_memory.append(obj_ptrs)
|
||||
to_cat_memory_pos_embed.append(obj_pos)
|
||||
num_obj_ptr_tokens = obj_ptrs.shape[0]
|
||||
else:
|
||||
num_obj_ptr_tokens = 0
|
||||
else:
|
||||
# for initial conditioning frames, encode them without using any previous memory
|
||||
if self.directly_add_no_mem_embed:
|
||||
# directly add no-mem embedding (instead of using the transformer encoder)
|
||||
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
# Step 2: Concatenate the memories and forward through the transformer encoder
|
||||
memory = torch.cat(to_cat_memory, dim=0)
|
||||
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
|
||||
|
||||
pix_feat_with_mem = self.memory_attention(
|
||||
curr=current_vision_feats,
|
||||
curr_pos=current_vision_pos_embeds,
|
||||
memory=memory,
|
||||
memory_pos=memory_pos_embed,
|
||||
num_obj_ptr_tokens=num_obj_ptr_tokens,
|
||||
)
|
||||
# reshape the output (HW)BC => BCHW
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
def _encode_new_memory(
|
||||
self,
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
pred_masks_high_res,
|
||||
is_mask_from_pts,
|
||||
):
|
||||
"""Encodes frame features and masks into a new memory representation for video segmentation."""
|
||||
B = current_vision_feats[-1].size(1) # batch size on this frame
|
||||
C = self.hidden_dim
|
||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
||||
# top-level feature, (HW)BC => BCHW
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
||||
if self.non_overlap_masks_for_mem_enc and not self.training:
|
||||
# optionally, apply non-overlapping constraints to the masks (it's applied
|
||||
# in the batch dimension and should only be used during eval, where all
|
||||
# the objects come from the same video under batch size 1).
|
||||
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
|
||||
# scale the raw mask logits with a temperature before applying sigmoid
|
||||
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
||||
if binarize and not self.training:
|
||||
mask_for_mem = (pred_masks_high_res > 0).float()
|
||||
else:
|
||||
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
|
||||
mask_for_mem = torch.sigmoid(pred_masks_high_res)
|
||||
# apply scale and bias terms to the sigmoid probabilities
|
||||
if self.sigmoid_scale_for_mem_enc != 1.0:
|
||||
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
||||
if self.sigmoid_bias_for_mem_enc != 0.0:
|
||||
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
||||
maskmem_out = self.memory_encoder(
|
||||
pix_feat,
|
||||
mask_for_mem,
|
||||
skip_mask_sigmoid=True, # sigmoid already applied
|
||||
)
|
||||
maskmem_features = maskmem_out["vision_features"]
|
||||
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
||||
|
||||
return maskmem_features, maskmem_pos_enc
|
||||
|
||||
def track_step(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
||||
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
|
||||
# to skip the memory encoder with `run_mem_encoder=False`. For example,
|
||||
# in demo we might call `track_step` multiple times for each user click,
|
||||
# and only encode the memory when the user finalizes their clicks. And in ablation
|
||||
# settings like SAM training on static images, we don't need the memory encoder.
|
||||
run_mem_encoder=True,
|
||||
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
|
||||
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
||||
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
||||
if len(current_vision_feats) > 1:
|
||||
high_res_features = [
|
||||
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
||||
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
||||
]
|
||||
else:
|
||||
high_res_features = None
|
||||
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
||||
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
||||
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
||||
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
||||
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
|
||||
else:
|
||||
# fused the visual feature with previous memory features in the memory bank
|
||||
pix_feat_with_mem = self._prepare_memory_conditioned_features(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats[-1:],
|
||||
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
||||
feat_sizes=feat_sizes[-1:],
|
||||
output_dict=output_dict,
|
||||
num_frames=num_frames,
|
||||
track_in_reverse=track_in_reverse,
|
||||
)
|
||||
# apply SAM-style segmentation head
|
||||
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
||||
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
||||
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
||||
if prev_sam_mask_logits is not None:
|
||||
assert point_inputs is not None and mask_inputs is None
|
||||
mask_inputs = prev_sam_mask_logits
|
||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
||||
sam_outputs = self._forward_sam_heads(
|
||||
backbone_features=pix_feat_with_mem,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
_,
|
||||
) = sam_outputs
|
||||
|
||||
current_out["pred_masks"] = low_res_masks
|
||||
current_out["pred_masks_high_res"] = high_res_masks
|
||||
current_out["obj_ptr"] = obj_ptr
|
||||
|
||||
# Finally run the memory encoder on the predicted mask to encode
|
||||
# it into a new memory feature (that can be used in future frames)
|
||||
if run_mem_encoder and self.num_maskmem > 0:
|
||||
high_res_masks_for_mem_enc = high_res_masks
|
||||
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks_for_mem_enc,
|
||||
is_mask_from_pts=(point_inputs is not None),
|
||||
)
|
||||
current_out["maskmem_features"] = maskmem_features
|
||||
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
current_out["maskmem_features"] = None
|
||||
current_out["maskmem_pos_enc"] = None
|
||||
|
||||
return current_out
|
||||
|
||||
def _use_multimask(self, is_init_cond_frame, point_inputs):
|
||||
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
|
||||
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
|
||||
multimask_output = (
|
||||
self.multimask_output_in_sam
|
||||
and (is_init_cond_frame or self.multimask_output_for_tracking)
|
||||
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
|
||||
)
|
||||
return multimask_output
|
||||
|
||||
def _apply_non_overlapping_constraints(self, pred_masks):
|
||||
"""Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
|
||||
batch_size = pred_masks.size(0)
|
||||
if batch_size == 1:
|
||||
return pred_masks
|
||||
|
||||
device = pred_masks.device
|
||||
# "max_obj_inds": object index of the object with the highest score at each location
|
||||
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
|
||||
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
|
||||
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
|
||||
keep = max_obj_inds == batch_obj_inds
|
||||
# suppress overlapping regions' scores below -10.0 so that the foreground regions
|
||||
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
|
||||
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
|
||||
return pred_masks
|
||||
|
|
|
|||
|
|
@ -17,16 +17,40 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d
|
||||
from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""A sequential container that performs 2D convolution followed by batch normalization."""
|
||||
"""
|
||||
A sequential container that performs 2D convolution followed by batch normalization.
|
||||
|
||||
Attributes:
|
||||
c (torch.nn.Conv2d): 2D convolution layer.
|
||||
1 (torch.nn.BatchNorm2d): Batch normalization layer.
|
||||
|
||||
Methods:
|
||||
__init__: Initializes the Conv2d_BN with specified parameters.
|
||||
|
||||
Args:
|
||||
a (int): Number of input channels.
|
||||
b (int): Number of output channels.
|
||||
ks (int): Kernel size for the convolution. Defaults to 1.
|
||||
stride (int): Stride for the convolution. Defaults to 1.
|
||||
pad (int): Padding for the convolution. Defaults to 0.
|
||||
dilation (int): Dilation factor for the convolution. Defaults to 1.
|
||||
groups (int): Number of groups for the convolution. Defaults to 1.
|
||||
bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1.
|
||||
|
||||
Examples:
|
||||
>>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output = conv_bn(input_tensor)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
|
||||
"""Initializes the MBConv model with given input channels, output channels, expansion ratio, activation, and
|
||||
drop path.
|
||||
"""
|
||||
"""Initializes a sequential container with 2D convolution followed by batch normalization."""
|
||||
super().__init__()
|
||||
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
|
|
@ -36,12 +60,29 @@ class Conv2d_BN(torch.nn.Sequential):
|
|||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Embeds images into patches and projects them into a specified embedding dimension."""
|
||||
"""
|
||||
Embeds images into patches and projects them into a specified embedding dimension.
|
||||
|
||||
Attributes:
|
||||
patches_resolution (Tuple[int, int]): Resolution of the patches after embedding.
|
||||
num_patches (int): Total number of patches.
|
||||
in_chans (int): Number of input channels.
|
||||
embed_dim (int): Dimension of the embedding.
|
||||
seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input tensor through the patch embedding sequence.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
|
||||
>>> x = torch.randn(1, 3, 224, 224)
|
||||
>>> output = patch_embed(x)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, embed_dim, resolution, activation):
|
||||
"""Initialize the PatchMerging class with specified input, output dimensions, resolution and activation
|
||||
function.
|
||||
"""
|
||||
"""Initializes patch embedding with convolutional layers for image-to-patch conversion and projection."""
|
||||
super().__init__()
|
||||
img_size: Tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
|
|
@ -56,17 +97,40 @@ class PatchEmbed(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Runs input tensor 'x' through the PatchMerging model's sequence of operations."""
|
||||
"""Processes input tensor through patch embedding sequence, converting images to patch embeddings."""
|
||||
return self.seq(x)
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture."""
|
||||
"""
|
||||
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
|
||||
|
||||
Attributes:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
out_chans (int): Number of output channels.
|
||||
conv1 (Conv2d_BN): First convolutional layer.
|
||||
act1 (nn.Module): First activation function.
|
||||
conv2 (Conv2d_BN): Depthwise convolutional layer.
|
||||
act2 (nn.Module): Second activation function.
|
||||
conv3 (Conv2d_BN): Final convolutional layer.
|
||||
act3 (nn.Module): Third activation function.
|
||||
drop_path (nn.Module): Drop path layer (Identity for inference).
|
||||
|
||||
Methods:
|
||||
forward: Performs the forward pass through the MBConv layer.
|
||||
|
||||
Examples:
|
||||
>>> in_chans, out_chans = 32, 64
|
||||
>>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
|
||||
>>> x = torch.randn(1, in_chans, 56, 56)
|
||||
>>> output = mbconv(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 64, 56, 56])
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path):
|
||||
"""Initializes a convolutional layer with specified dimensions, input resolution, depth, and activation
|
||||
function.
|
||||
"""
|
||||
"""Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation."""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
|
|
@ -86,7 +150,7 @@ class MBConv(nn.Module):
|
|||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Implements the forward pass for the model architecture."""
|
||||
"""Implements the forward pass of MBConv, applying convolutions and skip connection."""
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
|
|
@ -99,12 +163,34 @@ class MBConv(nn.Module):
|
|||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""Merges neighboring patches in the feature map and projects to a new dimension."""
|
||||
"""
|
||||
Merges neighboring patches in the feature map and projects to a new dimension.
|
||||
|
||||
This class implements a patch merging operation that combines spatial information and adjusts the feature
|
||||
dimension. It uses a series of convolutional layers with batch normalization to achieve this.
|
||||
|
||||
Attributes:
|
||||
input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map.
|
||||
dim (int): The input dimension of the feature map.
|
||||
out_dim (int): The output dimension after merging and projection.
|
||||
act (nn.Module): The activation function used between convolutions.
|
||||
conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
|
||||
conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
|
||||
conv3 (Conv2d_BN): The third convolutional layer for final projection.
|
||||
|
||||
Methods:
|
||||
forward: Applies the patch merging operation to the input tensor.
|
||||
|
||||
Examples:
|
||||
>>> input_resolution = (56, 56)
|
||||
>>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
|
||||
>>> x = torch.randn(4, 64, 56, 56)
|
||||
>>> output = patch_merging(x)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution, dim, out_dim, activation):
|
||||
"""Initializes the ConvLayer with specific dimension, input resolution, depth, activation, drop path, and other
|
||||
optional parameters.
|
||||
"""
|
||||
"""Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps."""
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
|
|
@ -117,7 +203,7 @@ class PatchMerging(nn.Module):
|
|||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies forward pass on the input utilizing convolution and activation layers, and returns the result."""
|
||||
"""Applies patch merging and dimension projection to the input feature map."""
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
|
|
@ -137,7 +223,24 @@ class ConvLayer(nn.Module):
|
|||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
Optionally applies downsample operations to the output, and provides support for gradient checkpointing.
|
||||
This layer optionally applies downsample operations to the output and supports gradient checkpointing.
|
||||
|
||||
Attributes:
|
||||
dim (int): Dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Resolution of the input image.
|
||||
depth (int): Number of MBConv layers in the block.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
blocks (nn.ModuleList): List of MBConv layers.
|
||||
downsample (Optional[Callable]): Function for downsampling the output.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input through the convolutional layers.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 64, 56, 56)
|
||||
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
|
||||
>>> output = conv_layer(input_tensor)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -155,16 +258,25 @@ class ConvLayer(nn.Module):
|
|||
"""
|
||||
Initializes the ConvLayer with the given dimensions and settings.
|
||||
|
||||
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
|
||||
optionally applies downsampling to the output.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): The resolution of the input image.
|
||||
depth (int): The number of MBConv layers in the block.
|
||||
activation (Callable): Activation function applied after each convolution.
|
||||
drop_path (Union[float, List[float]]): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
out_dim (Optional[int]): The dimensionality of the output. None means it will be the same as `dim`.
|
||||
conv_expand_ratio (float): Expansion ratio for the MBConv layers.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 64, 56, 56)
|
||||
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
|
||||
>>> output = conv_layer(input_tensor)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -194,7 +306,7 @@ class ConvLayer(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Processes the input through a series of convolutional layers and returns the activated output."""
|
||||
"""Processes input through convolutional layers, applying MBConv blocks and optional downsampling."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
|
@ -202,13 +314,33 @@ class ConvLayer(nn.Module):
|
|||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) for transformer architectures.
|
||||
Multi-layer Perceptron (MLP) module for transformer architectures.
|
||||
|
||||
This layer takes an input with in_features, applies layer normalization and two fully-connected layers.
|
||||
This module applies layer normalization, two fully-connected layers with an activation function in between,
|
||||
and dropout. It is commonly used in transformer-based architectures.
|
||||
|
||||
Attributes:
|
||||
norm (nn.LayerNorm): Layer normalization applied to the input.
|
||||
fc1 (nn.Linear): First fully-connected layer.
|
||||
fc2 (nn.Linear): Second fully-connected layer.
|
||||
act (nn.Module): Activation function applied after the first fully-connected layer.
|
||||
drop (nn.Dropout): Dropout layer applied after the activation function.
|
||||
|
||||
Methods:
|
||||
forward: Applies the MLP operations on the input tensor.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
>>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1)
|
||||
>>> x = torch.randn(32, 100, 256)
|
||||
>>> output = mlp(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([32, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
"""Initializes Attention module with the given parameters including dimension, key_dim, number of heads, etc."""
|
||||
"""Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions."""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
|
|
@ -219,7 +351,7 @@ class Mlp(nn.Module):
|
|||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies operations on input x and returns modified x, runs downsample if not None."""
|
||||
"""Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
|
|
@ -230,12 +362,37 @@ class Mlp(nn.Module):
|
|||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with support for spatial awareness, applying attention biases based on spatial
|
||||
resolution. Implements trainable attention biases for each unique offset between spatial positions in the resolution
|
||||
grid.
|
||||
Multi-head attention module with spatial awareness and trainable attention biases.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying
|
||||
attention biases based on spatial resolution. It includes trainable attention biases for each unique
|
||||
offset between spatial positions in the resolution grid.
|
||||
|
||||
Attributes:
|
||||
ab (Tensor, optional): Cached attention biases for inference, deleted during training.
|
||||
num_heads (int): Number of attention heads.
|
||||
scale (float): Scaling factor for attention scores.
|
||||
key_dim (int): Dimensionality of the keys and queries.
|
||||
nh_kd (int): Product of num_heads and key_dim.
|
||||
d (int): Dimensionality of the value vectors.
|
||||
dh (int): Product of d and num_heads.
|
||||
attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
|
||||
norm (nn.LayerNorm): Layer normalization applied to input.
|
||||
qkv (nn.Linear): Linear layer for computing query, key, and value projections.
|
||||
proj (nn.Linear): Linear layer for final projection.
|
||||
attention_biases (nn.Parameter): Learnable attention biases.
|
||||
attention_bias_idxs (Tensor): Indices for attention biases.
|
||||
ab (Tensor): Cached attention biases for inference, deleted during training.
|
||||
|
||||
Methods:
|
||||
train: Sets the module in training mode and handles the 'ab' attribute.
|
||||
forward: Performs the forward pass of the attention mechanism.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
|
||||
>>> x = torch.randn(1, 196, 256)
|
||||
>>> output = attn(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -247,17 +404,28 @@ class Attention(torch.nn.Module):
|
|||
resolution=(14, 14),
|
||||
):
|
||||
"""
|
||||
Initializes the Attention module.
|
||||
Initializes the Attention module for multi-head attention with spatial awareness.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying
|
||||
attention biases based on spatial resolution. It includes trainable attention biases for each unique
|
||||
offset between spatial positions in the resolution grid.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads. Default is 8.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int], optional): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
num_heads (int): Number of attention heads. Default is 8.
|
||||
attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4.
|
||||
resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14).
|
||||
|
||||
Raises:
|
||||
AssertionError: If `resolution` is not a tuple of length 2.
|
||||
AssertionError: If 'resolution' is not a tuple of length 2.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
|
||||
>>> x = torch.randn(1, 196, 256)
|
||||
>>> output = attn(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 256])
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -290,7 +458,7 @@ class Attention(torch.nn.Module):
|
|||
|
||||
@torch.no_grad()
|
||||
def train(self, mode=True):
|
||||
"""Sets the module in training mode and handles attribute 'ab' based on the mode."""
|
||||
"""Performs multi-head attention with spatial awareness and trainable attention biases."""
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, "ab"):
|
||||
del self.ab
|
||||
|
|
@ -298,7 +466,7 @@ class Attention(torch.nn.Module):
|
|||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x
|
||||
"""Performs forward pass over the input tensor 'x' by applying normalization and querying keys/values."""
|
||||
"""Applies multi-head attention with spatial awareness and trainable attention biases."""
|
||||
B, N, _ = x.shape # B, N, C
|
||||
|
||||
# Normalization
|
||||
|
|
@ -322,7 +490,34 @@ class Attention(torch.nn.Module):
|
|||
|
||||
|
||||
class TinyViTBlock(nn.Module):
|
||||
"""TinyViT Block that applies self-attention and a local convolution to the input."""
|
||||
"""
|
||||
TinyViT Block that applies self-attention and a local convolution to the input.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
|
||||
local convolutions to process input features efficiently.
|
||||
|
||||
Attributes:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Size of the attention window.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_path (nn.Module): Stochastic depth layer, identity function during inference.
|
||||
attn (Attention): Self-attention module.
|
||||
mlp (Mlp): Multi-layer perceptron module.
|
||||
local_conv (Conv2d_BN): Depth-wise local convolution layer.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input through the TinyViT block.
|
||||
extra_repr: Returns a string with extra information about the block's parameters.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 196, 192)
|
||||
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
|
||||
>>> output = block(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 192])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -337,22 +532,32 @@ class TinyViTBlock(nn.Module):
|
|||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initializes the TinyViTBlock.
|
||||
Initializes a TinyViT block with self-attention and local convolution.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
|
||||
local convolutions to process input features efficiently.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
dim (int): Dimensionality of the input and output features.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): Window size for attention. Default is 7.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float, optional): Stochastic depth rate. Default is 0.
|
||||
local_conv_size (int, optional): The kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
window_size (int): Size of the attention window. Must be greater than 0.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop (float): Dropout rate.
|
||||
drop_path (float): Stochastic depth rate.
|
||||
local_conv_size (int): Kernel size of the local convolution.
|
||||
activation (torch.nn.Module): Activation function for MLP.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `window_size` is not greater than 0.
|
||||
AssertionError: If `dim` is not divisible by `num_heads`.
|
||||
AssertionError: If window_size is not greater than 0.
|
||||
AssertionError: If dim is not divisible by num_heads.
|
||||
|
||||
Examples:
|
||||
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
|
||||
>>> input_tensor = torch.randn(1, 196, 192)
|
||||
>>> output = block(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 192])
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -380,9 +585,7 @@ class TinyViTBlock(nn.Module):
|
|||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies attention-based transformation or padding to input 'x' before passing it through a local
|
||||
convolution.
|
||||
"""
|
||||
"""Applies self-attention, local convolution, and MLP operations to the input tensor."""
|
||||
h, w = self.input_resolution
|
||||
b, hw, c = x.shape # batch, height*width, channels
|
||||
assert hw == h * w, "input feature has wrong size"
|
||||
|
|
@ -424,8 +627,19 @@ class TinyViTBlock(nn.Module):
|
|||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Returns a formatted string representing the TinyViTBlock's parameters: dimension, input resolution, number of
|
||||
attentions heads, window size, and MLP ratio.
|
||||
"""
|
||||
Returns a string representation of the TinyViTBlock's parameters.
|
||||
|
||||
This method provides a formatted string containing key information about the TinyViTBlock, including its
|
||||
dimension, input resolution, number of attention heads, window size, and MLP ratio.
|
||||
|
||||
Returns:
|
||||
(str): A formatted string containing the block's parameters.
|
||||
|
||||
Examples:
|
||||
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
|
||||
>>> print(block.extra_repr())
|
||||
dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
|
||||
"""
|
||||
return (
|
||||
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
||||
|
|
@ -434,7 +648,31 @@ class TinyViTBlock(nn.Module):
|
|||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
"""A basic TinyViT layer for one stage in a TinyViT architecture."""
|
||||
"""
|
||||
A basic TinyViT layer for one stage in a TinyViT architecture.
|
||||
|
||||
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
|
||||
and an optional downsampling operation.
|
||||
|
||||
Attributes:
|
||||
dim (int): The dimensionality of the input and output features.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks in this layer.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
|
||||
downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input through the layer's blocks and optional downsampling.
|
||||
extra_repr: Returns a string with the layer's parameters for printing.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 3136, 192)
|
||||
>>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
|
||||
>>> output = layer(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 784, 384])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -453,25 +691,34 @@ class BasicLayer(nn.Module):
|
|||
out_dim=None,
|
||||
):
|
||||
"""
|
||||
Initializes the BasicLayer.
|
||||
Initializes a BasicLayer in the TinyViT architecture.
|
||||
|
||||
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
|
||||
process feature maps at a specific resolution and dimensionality within the TinyViT model.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float, optional): Ratio of mlp hidden dim to embedding dim. Default is 4.
|
||||
drop (float, optional): Dropout rate. Default is 0.
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default is 0.
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default is None.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing to save memory. Default is False.
|
||||
local_conv_size (int, optional): Kernel size of the local convolution. Default is 3.
|
||||
activation (torch.nn, optional): Activation function for MLP. Default is nn.GELU.
|
||||
out_dim (int | None, optional): The output dimension of the layer. Default is None.
|
||||
dim (int): Dimensionality of the input and output features.
|
||||
input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width).
|
||||
depth (int): Number of TinyViT blocks in this layer.
|
||||
num_heads (int): Number of attention heads in each TinyViT block.
|
||||
window_size (int): Size of the local window for attention computation.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop (float): Dropout rate.
|
||||
drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block.
|
||||
downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
local_conv_size (int): Kernel size for the local convolution in each TinyViT block.
|
||||
activation (nn.Module): Activation function used in the MLP.
|
||||
out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `drop_path` is a list of float but its length doesn't match `depth`.
|
||||
ValueError: If `drop_path` is a list and its length doesn't match `depth`.
|
||||
|
||||
Examples:
|
||||
>>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
|
||||
>>> x = torch.randn(1, 56*56, 96)
|
||||
>>> output = layer(x)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -505,58 +752,49 @@ class BasicLayer(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""Performs forward propagation on the input tensor and returns a normalized tensor."""
|
||||
"""Processes input through TinyViT blocks and optional downsampling."""
|
||||
for blk in self.blocks:
|
||||
x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Returns a string representation of the extra_repr function with the layer's parameters."""
|
||||
"""Returns a string with the layer's parameters for printing."""
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
"""A PyTorch implementation of Layer Normalization in 2D."""
|
||||
|
||||
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
||||
"""Initialize LayerNorm2d with the number of channels and an optional epsilon."""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(num_channels))
|
||||
self.bias = nn.Parameter(torch.zeros(num_channels))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform a forward pass, normalizing the input tensor."""
|
||||
u = x.mean(1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.eps)
|
||||
return self.weight[:, None, None] * x + self.bias[:, None, None]
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
The TinyViT architecture for vision tasks.
|
||||
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
|
||||
|
||||
This class implements the TinyViT model, which combines elements of vision transformers and convolutional
|
||||
neural networks for improved efficiency and performance on vision tasks.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Input image size.
|
||||
in_chans (int): Number of input channels.
|
||||
num_classes (int): Number of classification classes.
|
||||
embed_dims (List[int]): List of embedding dimensions for each layer.
|
||||
depths (List[int]): List of depths for each layer.
|
||||
num_heads (List[int]): List of number of attention heads for each layer.
|
||||
window_sizes (List[int]): List of window sizes for each layer.
|
||||
depths (List[int]): Number of blocks in each stage.
|
||||
num_layers (int): Total number of layers in the network.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_rate (float): Dropout rate for drop layers.
|
||||
drop_path_rate (float): Drop path rate for stochastic depth.
|
||||
use_checkpoint (bool): Use checkpointing for efficient memory usage.
|
||||
mbconv_expand_ratio (float): Expansion ratio for MBConv layer.
|
||||
local_conv_size (int): Local convolution kernel size.
|
||||
layer_lr_decay (float): Layer-wise learning rate decay.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
patches_resolution (Tuple[int, int]): Resolution of embedded patches.
|
||||
layers (nn.ModuleList): List of network layers.
|
||||
norm_head (nn.LayerNorm): Layer normalization for the classifier head.
|
||||
head (nn.Linear): Linear layer for final classification.
|
||||
neck (nn.Sequential): Neck module for feature refinement.
|
||||
|
||||
Note:
|
||||
This implementation is generalized to accept a list of depths, attention heads,
|
||||
embedding dimensions and window sizes, which allows you to create a
|
||||
"stack" of TinyViT models of varying configurations.
|
||||
Methods:
|
||||
set_layer_lr_decay: Sets layer-wise learning rate decay.
|
||||
_init_weights: Initializes weights for linear and normalization layers.
|
||||
no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay.
|
||||
forward_features: Processes input through the feature extraction layers.
|
||||
forward: Performs a forward pass through the entire network.
|
||||
|
||||
Examples:
|
||||
>>> model = TinyViT(img_size=224, num_classes=1000)
|
||||
>>> x = torch.randn(1, 3, 224, 224)
|
||||
>>> features = model.forward_features(x)
|
||||
>>> print(features.shape)
|
||||
torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -579,21 +817,33 @@ class TinyViT(nn.Module):
|
|||
"""
|
||||
Initializes the TinyViT model.
|
||||
|
||||
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
|
||||
attention and convolution blocks, and a classification head.
|
||||
|
||||
Args:
|
||||
img_size (int, optional): The input image size. Defaults to 224.
|
||||
in_chans (int, optional): Number of input channels. Defaults to 3.
|
||||
num_classes (int, optional): Number of classification classes. Defaults to 1000.
|
||||
embed_dims (List[int], optional): List of embedding dimensions per layer. Defaults to [96, 192, 384, 768].
|
||||
depths (List[int], optional): List of depths for each layer. Defaults to [2, 2, 6, 2].
|
||||
num_heads (List[int], optional): List of number of attention heads per layer. Defaults to [3, 6, 12, 24].
|
||||
window_sizes (List[int], optional): List of window sizes for each layer. Defaults to [7, 7, 14, 7].
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension. Defaults to 4.
|
||||
drop_rate (float, optional): Dropout rate. Defaults to 0.
|
||||
drop_path_rate (float, optional): Drop path rate for stochastic depth. Defaults to 0.1.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing for efficient memory usage. Defaults to False.
|
||||
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer. Defaults to 4.0.
|
||||
local_conv_size (int, optional): Local convolution kernel size. Defaults to 3.
|
||||
layer_lr_decay (float, optional): Layer-wise learning rate decay. Defaults to 1.0.
|
||||
img_size (int): Size of the input image. Default is 224.
|
||||
in_chans (int): Number of input channels. Default is 3.
|
||||
num_classes (int): Number of classes for classification. Default is 1000.
|
||||
embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage.
|
||||
Default is (96, 192, 384, 768).
|
||||
depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2).
|
||||
num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage.
|
||||
Default is (3, 6, 12, 24).
|
||||
window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7).
|
||||
mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0.
|
||||
drop_rate (float): Dropout rate. Default is 0.0.
|
||||
drop_path_rate (float): Stochastic depth rate. Default is 0.1.
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False.
|
||||
mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0.
|
||||
local_conv_size (int): Kernel size for local convolutions. Default is 3.
|
||||
layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0.
|
||||
|
||||
Examples:
|
||||
>>> model = TinyViT(img_size=224, num_classes=1000)
|
||||
>>> x = torch.randn(1, 3, 224, 224)
|
||||
>>> output = model(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 1000])
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
|
@ -671,7 +921,7 @@ class TinyViT(nn.Module):
|
|||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay):
|
||||
"""Sets the learning rate decay for each layer in the TinyViT model."""
|
||||
"""Sets layer-wise learning rate decay for the TinyViT model based on depth."""
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# Layers -> blocks (depth)
|
||||
|
|
@ -706,7 +956,7 @@ class TinyViT(nn.Module):
|
|||
self.apply(_check_lr_scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
"""Initializes weights for linear layers and layer normalization in the given module."""
|
||||
"""Initializes weights for linear and normalization layers in the TinyViT model."""
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
|
|
@ -718,11 +968,11 @@ class TinyViT(nn.Module):
|
|||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
"""Returns a dictionary of parameter names where weight decay should not be applied."""
|
||||
"""Returns a set of keywords for parameters that should not use weight decay."""
|
||||
return {"attention_biases"}
|
||||
|
||||
def forward_features(self, x):
|
||||
"""Runs the input through the model layers and returns the transformed output."""
|
||||
"""Processes input through feature extraction layers, returning spatial features."""
|
||||
x = self.patch_embed(x) # x input is (N, C, H, W)
|
||||
|
||||
x = self.layers[0](x)
|
||||
|
|
@ -737,5 +987,5 @@ class TinyViT(nn.Module):
|
|||
return self.neck(x)
|
||||
|
||||
def forward(self, x):
|
||||
"""Executes a forward pass on the input tensor through the constructed model layers."""
|
||||
"""Performs the forward pass through the TinyViT model, extracting features from the input image."""
|
||||
return self.forward_features(x)
|
||||
|
|
|
|||
|
|
@ -11,19 +11,31 @@ from ultralytics.nn.modules import MLPBlock
|
|||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module that enables the simultaneous attention to both image and query points. This class
|
||||
serves as a specialized transformer decoder that attends to an input image using queries whose positional embedding
|
||||
is supplied. This is particularly useful for tasks like object detection, image segmentation, and point cloud
|
||||
processing.
|
||||
A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer decoder that attends to an input image using queries with
|
||||
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
|
||||
cloud processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): The number of layers in the transformer.
|
||||
embedding_dim (int): The channel dimension for the input embeddings.
|
||||
num_heads (int): The number of heads for multihead attention.
|
||||
mlp_dim (int): The internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): The list of TwoWayAttentionBlock layers that make up the transformer.
|
||||
final_attn_token_to_image (Attention): The final attention layer applied from the queries to the image.
|
||||
norm_final_attn (nn.LayerNorm): The layer normalization applied to the final queries.
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
|
||||
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
||||
|
||||
Methods:
|
||||
forward: Processes image and point embeddings through the transformer.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
||||
>>> image_pe = torch.randn(1, 256, 32, 32)
|
||||
>>> point_embedding = torch.randn(1, 100, 256)
|
||||
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
||||
>>> print(output_queries.shape, output_image.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -36,15 +48,33 @@ class TwoWayTransformer(nn.Module):
|
|||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer decoder that attends to an input image using queries whose positional embedding is supplied.
|
||||
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
||||
|
||||
Args:
|
||||
depth (int): number of layers in the transformer
|
||||
embedding_dim (int): the channel dimension for the input embeddings
|
||||
num_heads (int): the number of heads for multihead attention. Must
|
||||
divide embedding_dim
|
||||
mlp_dim (int): the channel dimension internal to the MLP block
|
||||
activation (nn.Module): the activation to use in the MLP block
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
activation (Type[nn.Module]): Activation function to use in the MLP block.
|
||||
attention_downsample_rate (int): Downsampling rate for attention mechanism.
|
||||
|
||||
Attributes:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers.
|
||||
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
||||
>>> image_pe = torch.randn(1, 256, 32, 32)
|
||||
>>> point_embedding = torch.randn(1, 100, 256)
|
||||
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
||||
>>> print(output_queries.shape, output_image.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
|
|
@ -75,15 +105,23 @@ class TwoWayTransformer(nn.Module):
|
|||
point_embedding: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Processes image and point embeddings through the Two-Way Transformer.
|
||||
|
||||
Args:
|
||||
image_embedding (torch.Tensor): image to attend to. Should be shape B x embedding_dim x h x w for any h and w.
|
||||
image_pe (torch.Tensor): the positional encoding to add to the image. Must have same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): the embedding to add to the query points.
|
||||
Must have shape B x N_points x embedding_dim for any N_points.
|
||||
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
|
||||
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): the processed point_embedding
|
||||
(torch.Tensor): the processed image_embedding
|
||||
(Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
||||
>>> image_pe = torch.randn(1, 256, 32, 32)
|
||||
>>> point_embedding = torch.randn(1, 100, 256)
|
||||
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
||||
>>> print(output_queries.shape, output_image.shape)
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
|
|
@ -114,21 +152,34 @@ class TwoWayTransformer(nn.Module):
|
|||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that performs both self-attention and cross-attention in two directions: queries to keys and
|
||||
keys to queries. This block consists of four main layers: (1) self-attention on sparse inputs, (2) cross-attention
|
||||
of sparse inputs to dense inputs, (3) an MLP block on sparse inputs, and (4) cross-attention of dense inputs to
|
||||
sparse inputs.
|
||||
A two-way attention block for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
|
||||
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
|
||||
inputs to sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): The self-attention layer for the queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
|
||||
self_attn (Attention): Self-attention layer for queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization after self-attention.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
|
||||
mlp (MLPBlock): MLP block that transforms the query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
|
||||
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
|
||||
mlp (MLPBlock): MLP block for transforming query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization after MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
|
||||
|
||||
Methods:
|
||||
forward: Applies self-attention and cross-attention to queries and keys.
|
||||
|
||||
Examples:
|
||||
>>> embedding_dim, num_heads = 256, 8
|
||||
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
|
||||
>>> queries = torch.randn(1, 100, embedding_dim)
|
||||
>>> keys = torch.randn(1, 1000, embedding_dim)
|
||||
>>> query_pe = torch.randn(1, 100, embedding_dim)
|
||||
>>> key_pe = torch.randn(1, 1000, embedding_dim)
|
||||
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -141,16 +192,28 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
A transformer block with four layers: (1) self-attention of sparse inputs, (2) cross attention of sparse
|
||||
inputs to dense inputs, (3) mlp block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
||||
inputs.
|
||||
Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
||||
|
||||
This block implements a specialized transformer layer with four main components: self-attention on sparse
|
||||
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
|
||||
of dense inputs to sparse inputs.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): the channel dimension of the embeddings
|
||||
num_heads (int): the number of heads in the attention layers
|
||||
mlp_dim (int): the hidden dimension of the mlp block
|
||||
activation (nn.Module): the activation of the mlp block
|
||||
skip_first_layer_pe (bool): skip the PE on the first layer
|
||||
embedding_dim (int): Channel dimension of the embeddings.
|
||||
num_heads (int): Number of attention heads in the attention layers.
|
||||
mlp_dim (int): Hidden dimension of the MLP block.
|
||||
activation (Type[nn.Module]): Activation function for the MLP block.
|
||||
attention_downsample_rate (int): Downsampling rate for the attention mechanism.
|
||||
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
|
||||
|
||||
Examples:
|
||||
>>> embedding_dim, num_heads = 256, 8
|
||||
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
|
||||
>>> queries = torch.randn(1, 100, embedding_dim)
|
||||
>>> keys = torch.randn(1, 1000, embedding_dim)
|
||||
>>> query_pe = torch.randn(1, 100, embedding_dim)
|
||||
>>> key_pe = torch.randn(1, 1000, embedding_dim)
|
||||
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
|
|
@ -168,7 +231,7 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Apply self-attention and cross-attention to queries and keys and return the processed embeddings."""
|
||||
"""Applies two-way attention to process query and key embeddings in a transformer block."""
|
||||
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
|
|
@ -202,8 +265,34 @@ class TwoWayAttentionBlock(nn.Module):
|
|||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""An attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
|
||||
values.
|
||||
"""
|
||||
An attention layer with downscaling capability for embedding size after projection.
|
||||
|
||||
This class implements a multi-head attention mechanism with the option to downsample the internal
|
||||
dimension of queries, keys, and values.
|
||||
|
||||
Attributes:
|
||||
embedding_dim (int): Dimensionality of input embeddings.
|
||||
kv_in_dim (int): Dimensionality of key and value inputs.
|
||||
internal_dim (int): Internal dimension after downsampling.
|
||||
num_heads (int): Number of attention heads.
|
||||
q_proj (nn.Linear): Linear projection for queries.
|
||||
k_proj (nn.Linear): Linear projection for keys.
|
||||
v_proj (nn.Linear): Linear projection for values.
|
||||
out_proj (nn.Linear): Linear projection for output.
|
||||
|
||||
Methods:
|
||||
_separate_heads: Separates input tensor into attention heads.
|
||||
_recombine_heads: Recombines separated attention heads.
|
||||
forward: Computes attention output for given query, key, and value tensors.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
|
||||
>>> q = torch.randn(1, 100, 256)
|
||||
>>> k = v = torch.randn(1, 50, 256)
|
||||
>>> output = attn(q, k, v)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -214,15 +303,27 @@ class Attention(nn.Module):
|
|||
kv_in_dim: int = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the Attention model with the given dimensions and settings.
|
||||
Initializes the Attention module with specified dimensions and settings.
|
||||
|
||||
This class implements a multi-head attention mechanism with optional downsampling of the internal
|
||||
dimension for queries, keys, and values.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): The dimensionality of the input embeddings.
|
||||
num_heads (int): The number of attention heads.
|
||||
downsample_rate (int, optional): The factor by which the internal dimensions are downsampled. Defaults to 1.
|
||||
embedding_dim (int): Dimensionality of input embeddings.
|
||||
num_heads (int): Number of attention heads.
|
||||
downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1.
|
||||
kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim.
|
||||
|
||||
Raises:
|
||||
AssertionError: If 'num_heads' does not evenly divide the internal dim (embedding_dim / downsample_rate).
|
||||
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
|
||||
>>> q = torch.randn(1, 100, 256)
|
||||
>>> k = v = torch.randn(1, 50, 256)
|
||||
>>> output = attn(q, k, v)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 100, 256])
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
|
|
@ -238,20 +339,20 @@ class Attention(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _separate_heads(x: Tensor, num_heads: int) -> Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
"""Separates the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
@staticmethod
|
||||
def _recombine_heads(x: Tensor) -> Tensor:
|
||||
"""Recombine the separated attention heads into a single tensor."""
|
||||
"""Recombines separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
||||
"""Compute the attention output given the input query, key, and value tensors."""
|
||||
"""Applies multi-head attention to query, key, and value tensors with optional downsampling."""
|
||||
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
|
@ -70,7 +72,7 @@ def get_1d_sine_pe(pos_inds, dim, temperature=10000):
|
|||
|
||||
|
||||
def init_t_xy(end_x: int, end_y: int):
|
||||
"""Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y."""
|
||||
"""Initializes 1D and 2D coordinate tensors for a grid of specified dimensions."""
|
||||
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
||||
t_x = (t % end_x).float()
|
||||
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
||||
|
|
@ -78,7 +80,7 @@ def init_t_xy(end_x: int, end_y: int):
|
|||
|
||||
|
||||
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
||||
"""Computes axial complex exponential positional encodings for 2D spatial positions."""
|
||||
"""Computes axial complex exponential positional encodings for 2D spatial positions in a grid."""
|
||||
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
||||
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
||||
|
||||
|
|
@ -91,7 +93,7 @@ def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
|||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
"""Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions."""
|
||||
"""Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility."""
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
||||
|
|
@ -189,3 +191,103 @@ def window_unpartition(windows, window_size, pad_hw, hw):
|
|||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extracts relative positional embeddings based on query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): Size of the query.
|
||||
k_size (int): Size of the key.
|
||||
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
|
||||
distance and C is the embedding dimension.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
|
||||
k_size, C).
|
||||
|
||||
Examples:
|
||||
>>> q_size, k_size = 8, 16
|
||||
>>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
|
||||
>>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
|
||||
>>> print(extracted_pos.shape)
|
||||
torch.Size([8, 16, 64])
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate rel pos if needed.
|
||||
if rel_pos.shape[0] != max_rel_dist:
|
||||
# Interpolate rel pos.
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coords with short length if shapes for q and k are different.
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: Tuple[int, int],
|
||||
k_size: Tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Adds decomposed Relative Positional Embeddings to the attention map.
|
||||
|
||||
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
|
||||
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
|
||||
positions.
|
||||
|
||||
Args:
|
||||
attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
|
||||
q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
|
||||
rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
|
||||
q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
|
||||
k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
|
||||
(B, q_h * q_w, k_h * k_w).
|
||||
|
||||
Examples:
|
||||
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
|
||||
>>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
|
||||
>>> q = torch.rand(B, q_h * q_w, C)
|
||||
>>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
|
||||
>>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
|
||||
>>> q_size, k_size = (q_h, q_w), (k_h, k_w)
|
||||
>>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
|
||||
>>> print(updated_attn.shape)
|
||||
torch.Size([1, 64, 64])
|
||||
|
||||
References:
|
||||
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w
|
||||
)
|
||||
|
||||
return attn
|
||||
|
|
@ -34,35 +34,64 @@ from .build import build_sam
|
|||
|
||||
class Predictor(BasePredictor):
|
||||
"""
|
||||
Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
|
||||
Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
|
||||
|
||||
The class provides an interface for model inference tailored to image segmentation tasks.
|
||||
With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
|
||||
mask generation. The class is capable of working with various types of prompts such as bounding boxes,
|
||||
points, and low-resolution masks.
|
||||
This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
|
||||
segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
|
||||
fine-grained control over segmentation results.
|
||||
|
||||
Attributes:
|
||||
cfg (dict): Configuration dictionary specifying model and task-related parameters.
|
||||
overrides (dict): Dictionary containing values that override the default configuration.
|
||||
_callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
|
||||
args (namespace): Namespace to hold command-line arguments or other operational variables.
|
||||
im (torch.Tensor): Preprocessed input image tensor.
|
||||
features (torch.Tensor): Extracted image features used for inference.
|
||||
prompts (dict): Collection of various prompt types, such as bounding boxes and points.
|
||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
|
||||
args (SimpleNamespace): Configuration arguments for the predictor.
|
||||
model (torch.nn.Module): The loaded SAM model.
|
||||
device (torch.device): The device (CPU or GPU) on which the model is loaded.
|
||||
im (torch.Tensor): The preprocessed input image.
|
||||
features (torch.Tensor): Extracted image features.
|
||||
prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
|
||||
segment_all (bool): Flag to indicate if full image segmentation should be performed.
|
||||
mean (torch.Tensor): Mean values for image normalization.
|
||||
std (torch.Tensor): Standard deviation values for image normalization.
|
||||
|
||||
Methods:
|
||||
preprocess: Prepares input images for model inference.
|
||||
pre_transform: Performs initial transformations on the input image.
|
||||
inference: Performs segmentation inference based on input prompts.
|
||||
prompt_inference: Internal function for prompt-based segmentation inference.
|
||||
generate: Generates segmentation masks for an entire image.
|
||||
setup_model: Initializes the SAM model for inference.
|
||||
get_model: Builds and returns a SAM model.
|
||||
postprocess: Post-processes model outputs to generate final results.
|
||||
setup_source: Sets up the data source for inference.
|
||||
set_image: Sets and preprocesses a single image for inference.
|
||||
get_im_features: Extracts image features using the SAM image encoder.
|
||||
set_prompts: Sets prompts for subsequent inference.
|
||||
reset_image: Resets the current image and its features.
|
||||
remove_small_regions: Removes small disconnected regions and holes from masks.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> masks, scores, boxes = predictor.generate()
|
||||
>>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the Predictor with configuration, overrides, and callbacks.
|
||||
|
||||
The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
|
||||
initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
|
||||
Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
|
||||
callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
|
||||
for optimal results.
|
||||
|
||||
Args:
|
||||
cfg (dict): Configuration dictionary.
|
||||
overrides (dict, optional): Dictionary of values to override default configuration.
|
||||
_callbacks (dict, optional): Dictionary of callback functions to customize behavior.
|
||||
cfg (Dict): Configuration dictionary containing default settings.
|
||||
overrides (Dict | None): Dictionary of values to override default configuration.
|
||||
_callbacks (Dict | None): Dictionary of callback functions to customize behavior.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor(cfg=DEFAULT_CFG)
|
||||
>>> predictor = Predictor(overrides={'imgsz': 640})
|
||||
>>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
|
|
@ -78,14 +107,19 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
Preprocess the input image for model inference.
|
||||
|
||||
The method prepares the input image by applying transformations and normalization.
|
||||
It supports both torch.Tensor and list of np.ndarray as input formats.
|
||||
This method prepares the input image by applying transformations and normalization. It supports both
|
||||
torch.Tensor and list of np.ndarray as input formats.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
|
||||
im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The preprocessed image tensor.
|
||||
(torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> image = torch.rand(1, 3, 640, 640)
|
||||
>>> preprocessed_image = predictor.preprocess(image)
|
||||
"""
|
||||
if self.im is not None:
|
||||
return self.im
|
||||
|
|
@ -106,14 +140,24 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
Perform initial transformations on the input image for preprocessing.
|
||||
|
||||
The method applies transformations such as resizing to prepare the image for further preprocessing.
|
||||
This method applies transformations such as resizing to prepare the image for further preprocessing.
|
||||
Currently, batched inference is not supported; hence the list length should be 1.
|
||||
|
||||
Args:
|
||||
im (List[np.ndarray]): List containing images in HWC numpy array format.
|
||||
im (List[np.ndarray]): List containing a single image in HWC numpy array format.
|
||||
|
||||
Returns:
|
||||
(List[np.ndarray]): List of transformed images.
|
||||
(List[np.ndarray]): List containing the transformed image.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the input list contains more than one image.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> image = np.random.rand(480, 640, 3) # Single HWC image
|
||||
>>> transformed = predictor.pre_transform([image])
|
||||
>>> print(len(transformed))
|
||||
1
|
||||
"""
|
||||
assert len(im) == 1, "SAM model does not currently support batched inference"
|
||||
letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
|
||||
|
|
@ -121,23 +165,32 @@ class Predictor(BasePredictor):
|
|||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
|
||||
"""
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image.
|
||||
|
||||
This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
|
||||
encoder, and mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
|
||||
bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
|
||||
multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
|
||||
*args (Any): Additional positional arguments.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
(tuple): Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
(tuple): Contains the following three elements:
|
||||
- np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
- np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model_path='sam_model.pt')
|
||||
>>> predictor.set_image('image.jpg')
|
||||
>>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
|
||||
"""
|
||||
# Override prompts if any stored in self.prompts
|
||||
bboxes = self.prompts.pop("bboxes", bboxes)
|
||||
|
|
@ -151,22 +204,30 @@ class Predictor(BasePredictor):
|
|||
|
||||
def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
|
||||
"""
|
||||
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
|
||||
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
|
||||
Performs image segmentation inference based on input cues using SAM's specialized architecture.
|
||||
|
||||
This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
|
||||
It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
|
||||
multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
|
||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
|
||||
masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
|
||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||
|
||||
Returns:
|
||||
(tuple): Contains the following three elements.
|
||||
- np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores predicted by the model for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> im = torch.rand(1, 3, 1024, 1024)
|
||||
>>> bboxes = [[100, 100, 200, 200]]
|
||||
>>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
|
||||
"""
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
|
|
@ -224,27 +285,32 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
Perform image segmentation using the Segment Anything Model (SAM).
|
||||
|
||||
This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
|
||||
This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
|
||||
and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
|
||||
crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
|
||||
Each layer produces 2**i_layer number of image crops.
|
||||
crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
|
||||
crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
|
||||
point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
|
||||
Used in the nth crop layer.
|
||||
points_stride (int, optional): Number of points to sample along each side of the image.
|
||||
Exclusive with 'point_grids'.
|
||||
im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
|
||||
crop_n_layers (int): Number of layers for additional mask predictions on image crops.
|
||||
crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
|
||||
crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
|
||||
point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
|
||||
points_stride (int): Number of points to sample along each side of the image.
|
||||
points_batch_size (int): Batch size for the number of points processed simultaneously.
|
||||
conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
|
||||
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
|
||||
conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
|
||||
stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
|
||||
stability_score_offset (float): Offset value for calculating stability score.
|
||||
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
|
||||
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
|
||||
- pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> im = torch.rand(1, 3, 1024, 1024) # Example input image
|
||||
>>> masks, scores, boxes = predictor.generate(im)
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
|
|
@ -326,11 +392,9 @@ class Predictor(BasePredictor):
|
|||
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
|
||||
verbose (bool): If True, prints selected device information.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
|
||||
device (torch.device): The device to which the model and tensors are allocated.
|
||||
mean (torch.Tensor): The mean values for image normalization.
|
||||
std (torch.Tensor): The standard deviation values for image normalization.
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_model(model=sam_model, verbose=True)
|
||||
"""
|
||||
device = select_device(self.args.device, verbose=verbose)
|
||||
if model is None:
|
||||
|
|
@ -349,23 +413,32 @@ class Predictor(BasePredictor):
|
|||
self.done_warmup = True
|
||||
|
||||
def get_model(self):
|
||||
"""Built Segment Anything Model (SAM) model."""
|
||||
"""Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
|
||||
return build_sam(self.args.model)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
|
||||
|
||||
The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
|
||||
The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
|
||||
This method scales masks and boxes to the original image size and applies a threshold to the mask
|
||||
predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
|
||||
img (torch.Tensor): The processed input image tensor.
|
||||
orig_imgs (list | torch.Tensor): The original, unprocessed images.
|
||||
preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
|
||||
- pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
|
||||
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
|
||||
- pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
|
||||
img (torch.Tensor): The processed input image tensor with shape (C, H, W).
|
||||
orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(list): List of Results objects containing detection masks, bounding boxes, and other metadata.
|
||||
(List[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
metadata for each processed image.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> preds = predictor.inference(img)
|
||||
>>> results = predictor.postprocess(preds, img, orig_imgs)
|
||||
"""
|
||||
# (N, 1, H, W), (N, 1)
|
||||
pred_masks, pred_scores = preds[:2]
|
||||
|
|
@ -393,11 +466,23 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
Sets up the data source for inference.
|
||||
|
||||
This method configures the data source from which images will be fetched for inference. The source could be a
|
||||
directory, a video file, or other types of image data sources.
|
||||
This method configures the data source from which images will be fetched for inference. It supports
|
||||
various input types such as image files, directories, video files, and other compatible data sources.
|
||||
|
||||
Args:
|
||||
source (str | Path): The path to the image data source for inference.
|
||||
source (str | Path | None): The path or identifier for the image data source. Can be a file path,
|
||||
directory path, URL, or other supported source types.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.setup_source('path/to/images')
|
||||
>>> predictor.setup_source('video.mp4')
|
||||
>>> predictor.setup_source(None) # Uses default source if available
|
||||
|
||||
Notes:
|
||||
- If source is None, the method may use a default source if configured.
|
||||
- The method adapts to different source types and prepares them for subsequent inference steps.
|
||||
- Supported source types may include local files, directories, URLs, and video streams.
|
||||
"""
|
||||
if source is not None:
|
||||
super().setup_source(source)
|
||||
|
|
@ -406,14 +491,25 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
Preprocesses and sets a single image for inference.
|
||||
|
||||
This function sets up the model if not already initialized, configures the data source to the specified image,
|
||||
and preprocesses the image for feature extraction. Only one image can be set at a time.
|
||||
This method prepares the model for inference on a single image by setting up the model if not already
|
||||
initialized, configuring the data source, and preprocessing the image for feature extraction. It
|
||||
ensures that only one image is set at a time and extracts image features for subsequent use.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
|
||||
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
|
||||
an image read by cv2.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is set.
|
||||
AssertionError: If more than one image is attempted to be set.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
>>> predictor.set_image('path/to/image.jpg')
|
||||
>>> predictor.set_image(cv2.imread('path/to/image.jpg'))
|
||||
|
||||
Notes:
|
||||
- This method should be called before performing inference on a new image.
|
||||
- The extracted features are stored in the `self.features` attribute for later use.
|
||||
"""
|
||||
if self.model is None:
|
||||
self.setup_model(model=None)
|
||||
|
|
@ -425,35 +521,44 @@ class Predictor(BasePredictor):
|
|||
break
|
||||
|
||||
def get_im_features(self, im):
|
||||
"""Get image features from the SAM image encoder."""
|
||||
"""Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
|
||||
return self.model.image_encoder(im)
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""Set prompts in advance."""
|
||||
"""Sets prompts for subsequent inference operations."""
|
||||
self.prompts = prompts
|
||||
|
||||
def reset_image(self):
|
||||
"""Resets the image and its features to None."""
|
||||
"""Resets the current image and its features, clearing them for subsequent inference."""
|
||||
self.im = None
|
||||
self.features = None
|
||||
|
||||
@staticmethod
|
||||
def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
|
||||
"""
|
||||
Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
|
||||
function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
|
||||
Remove small disconnected regions and holes from segmentation masks.
|
||||
|
||||
This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
|
||||
It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
|
||||
Suppression (NMS) to eliminate any newly created duplicate boxes.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
|
||||
the number of masks, H is height, and W is width.
|
||||
min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
|
||||
nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
|
||||
masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
|
||||
masks, H is height, and W is width.
|
||||
min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
|
||||
this will be removed.
|
||||
nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
|
||||
|
||||
Returns:
|
||||
(tuple([torch.Tensor, List[int]])):
|
||||
- new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
|
||||
- keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
|
||||
(tuple):
|
||||
- new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
|
||||
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
|
||||
|
||||
Examples:
|
||||
>>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
|
||||
>>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
|
||||
>>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
|
||||
>>> print(f"Indices of kept masks: {keep}")
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
|
|
@ -480,3 +585,188 @@ class Predictor(BasePredictor):
|
|||
keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
|
||||
|
||||
return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
|
||||
|
||||
|
||||
class SAM2Predictor(Predictor):
|
||||
"""
|
||||
SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
|
||||
|
||||
This class extends the base Predictor class to implement SAM2-specific functionality for image
|
||||
segmentation tasks. It provides methods for model initialization, feature extraction, and
|
||||
prompt-based inference.
|
||||
|
||||
Attributes:
|
||||
_bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
|
||||
model (torch.nn.Module): The loaded SAM2 model.
|
||||
device (torch.device): The device (CPU or GPU) on which the model is loaded.
|
||||
features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
|
||||
segment_all (bool): Flag to indicate if all segments should be predicted.
|
||||
prompts (Dict): Dictionary to store various types of prompts for inference.
|
||||
|
||||
Methods:
|
||||
get_model: Retrieves and initializes the SAM2 model.
|
||||
prompt_inference: Performs image segmentation inference based on various prompts.
|
||||
set_image: Preprocesses and sets a single image for inference.
|
||||
get_im_features: Extracts and processes image features using SAM2's image encoder.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(cfg)
|
||||
>>> predictor.set_image("path/to/image.jpg")
|
||||
>>> bboxes = [[100, 100, 200, 200]]
|
||||
>>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
|
||||
>>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
|
||||
"""
|
||||
|
||||
_bb_feat_sizes = [
|
||||
(256, 256),
|
||||
(128, 128),
|
||||
(64, 64),
|
||||
]
|
||||
|
||||
def get_model(self):
|
||||
"""Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
|
||||
return build_sam(self.args.model)
|
||||
|
||||
def prompt_inference(
|
||||
self,
|
||||
im,
|
||||
bboxes=None,
|
||||
points=None,
|
||||
labels=None,
|
||||
masks=None,
|
||||
multimask_output=False,
|
||||
img_idx=-1,
|
||||
):
|
||||
"""
|
||||
Performs image segmentation inference based on various prompts using SAM2 architecture.
|
||||
|
||||
This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
|
||||
based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
|
||||
multi-object prediction scenarios.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||
points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
|
||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||
img_idx (int): Index of the image in the batch to process.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(cfg)
|
||||
>>> image = torch.rand(1, 3, 640, 640)
|
||||
>>> bboxes = [[100, 100, 200, 200]]
|
||||
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
|
||||
>>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
|
||||
|
||||
Notes:
|
||||
- The method supports batched inference for multiple objects when points or bboxes are provided.
|
||||
- Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
|
||||
- When both bboxes and points are provided, they are merged into a single 'points' input for the model.
|
||||
|
||||
References:
|
||||
- SAM2 Paper: [Add link to SAM2 paper when available]
|
||||
"""
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
||||
# Transform input prompts
|
||||
if points is not None:
|
||||
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
|
||||
points = points[None] if points.ndim == 1 else points
|
||||
# Assuming labels are all positive if users don't pass labels.
|
||||
if labels is None:
|
||||
labels = torch.ones(points.shape[0])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
points *= r
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None], labels[:, None]
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bboxes = bboxes.view(-1, 2, 2) * r
|
||||
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
|
||||
# NOTE: merge "boxes" and "points" into a single "points" input
|
||||
# (where boxes are added at the beginning) to model.sam_prompt_encoder
|
||||
if points is not None:
|
||||
points = torch.cat([bboxes, points], dim=1)
|
||||
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||
else:
|
||||
points, labels = bboxes, bbox_labels
|
||||
if masks is not None:
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||
points=points,
|
||||
boxes=None,
|
||||
masks=masks,
|
||||
)
|
||||
# Predict masks
|
||||
batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
|
||||
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
|
||||
pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
|
||||
image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
|
||||
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
repeat_image=batched_mode,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
|
||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
Preprocesses and sets a single image for inference using the SAM2 model.
|
||||
|
||||
This method initializes the model if not already done, configures the data source to the specified image,
|
||||
and preprocesses the image for feature extraction. It supports setting only one image at a time.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is attempted to be set.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor()
|
||||
>>> predictor.set_image("path/to/image.jpg")
|
||||
>>> predictor.set_image(np.array([...])) # Using a numpy array
|
||||
|
||||
Notes:
|
||||
- This method must be called before performing any inference on a new image.
|
||||
- The method caches the extracted features for efficient subsequent inferences on the same image.
|
||||
- Only one image can be set at a time. To process multiple images, call this method for each new image.
|
||||
"""
|
||||
if self.model is None:
|
||||
self.setup_model(model=None)
|
||||
self.setup_source(image)
|
||||
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
|
||||
for batch in self.dataset:
|
||||
im = self.preprocess(batch[1])
|
||||
self.features = self.get_im_features(im)
|
||||
break
|
||||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts image features from the SAM image encoder for subsequent processing."""
|
||||
backbone_out = self.model.forward_image(im)
|
||||
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
||||
if self.model.directly_add_no_mem_embed:
|
||||
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
||||
feats = [
|
||||
feat.permute(1, 2, 0).view(1, -1, *feat_size)
|
||||
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
||||
][::-1]
|
||||
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import SAM2
|
||||
from .predict import SAM2Predictor
|
||||
|
||||
__all__ = "SAM2", "SAM2Predictor" # tuple or list
|
||||
|
|
@ -1,156 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder
|
||||
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
|
||||
from .modules.sam2 import SAM2Model
|
||||
|
||||
|
||||
def build_sam2_t(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 7, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[5, 7, 9],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_s(checkpoint=None):
|
||||
"""Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 11, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[7, 10, 13],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_b(checkpoint=None):
|
||||
"""Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=112,
|
||||
encoder_stages=[2, 3, 16, 3],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[12, 16, 20],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_window_spatial_size=[14, 14],
|
||||
encoder_backbone_channel_list=[896, 448, 224, 112],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_l(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=144,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[23, 33, 43],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_sam2(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[7, 15, 23, 31],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
encoder_window_spatial_size=[7, 7],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
checkpoint=None,
|
||||
):
|
||||
"""Builds a SAM2 model with specified architecture parameters and optional checkpoint loading."""
|
||||
image_encoder = ImageEncoder(
|
||||
trunk=Hiera(
|
||||
embed_dim=encoder_embed_dim,
|
||||
num_heads=encoder_num_heads,
|
||||
stages=encoder_stages,
|
||||
global_att_blocks=encoder_global_att_blocks,
|
||||
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
|
||||
window_spec=encoder_window_spec,
|
||||
),
|
||||
neck=FpnNeck(
|
||||
d_model=256,
|
||||
backbone_channel_list=encoder_backbone_channel_list,
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
scalp=1,
|
||||
)
|
||||
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
|
||||
memory_encoder = MemoryEncoder(out_dim=64)
|
||||
|
||||
sam2 = SAM2Model(
|
||||
image_encoder=image_encoder,
|
||||
memory_attention=memory_attention,
|
||||
memory_encoder=memory_encoder,
|
||||
num_maskmem=7,
|
||||
image_size=1024,
|
||||
sigmoid_scale_for_mem_enc=20.0,
|
||||
sigmoid_bias_for_mem_enc=-10.0,
|
||||
use_mask_input_as_output_without_sam=True,
|
||||
directly_add_no_mem_embed=True,
|
||||
use_high_res_features_in_sam=True,
|
||||
multimask_output_in_sam=True,
|
||||
iou_prediction_use_sigmoid=True,
|
||||
use_obj_ptrs_in_encoder=True,
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
only_obj_ptrs_in_the_past_for_eval=True,
|
||||
pred_obj_scores=True,
|
||||
pred_obj_scores_mlp=True,
|
||||
fixed_no_obj_ptr=True,
|
||||
multimask_output_for_tracking=True,
|
||||
use_multimask_token_for_obj_ptr=True,
|
||||
multimask_min_pt_num=0,
|
||||
multimask_max_pt_num=1,
|
||||
use_mlp_for_obj_ptr_proj=True,
|
||||
compile_image_encoder=False,
|
||||
sam_mask_decoder_extra_args=dict(
|
||||
dynamic_multimask_via_stability=True,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
),
|
||||
)
|
||||
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)["model"]
|
||||
sam2.load_state_dict(state_dict)
|
||||
sam2.eval()
|
||||
return sam2
|
||||
|
||||
|
||||
sam_model_map = {
|
||||
"sam2_t.pt": build_sam2_t,
|
||||
"sam2_s.pt": build_sam2_s,
|
||||
"sam2_b.pt": build_sam2_b,
|
||||
"sam2_l.pt": build_sam2_l,
|
||||
}
|
||||
|
||||
|
||||
def build_sam2(ckpt="sam_b.pt"):
|
||||
"""Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options."""
|
||||
model_builder = None
|
||||
ckpt = str(ckpt) # to allow Path ckpt types
|
||||
for k in sam_model_map.keys():
|
||||
if ckpt.endswith(k):
|
||||
model_builder = sam_model_map.get(k)
|
||||
|
||||
if not model_builder:
|
||||
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
|
||||
|
||||
return model_builder(ckpt)
|
||||
|
|
@ -1,97 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
SAM2 model interface.
|
||||
|
||||
This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image
|
||||
segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis,
|
||||
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
|
||||
image distributions and tasks without prior knowledge.
|
||||
|
||||
Key Features:
|
||||
- Promptable segmentation
|
||||
- Real-time performance
|
||||
- Zero-shot transfer capabilities
|
||||
- Trained on SA-1B dataset
|
||||
"""
|
||||
|
||||
from ultralytics.models.sam import SAM
|
||||
|
||||
from .build import build_sam2
|
||||
from .predict import SAM2Predictor
|
||||
|
||||
|
||||
class SAM2(SAM):
|
||||
"""
|
||||
SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2).
|
||||
|
||||
This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation
|
||||
tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The loaded SAM2 model.
|
||||
task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor.
|
||||
|
||||
Methods:
|
||||
__init__: Initializes the SAM2 model with pre-trained weights.
|
||||
_load: Loads specified weights into the SAM2 model.
|
||||
|
||||
Examples:
|
||||
>>> sam2 = SAM2("sam2_b.pt")
|
||||
>>> sam2._load('path/to/sam2_weights.pt')
|
||||
>>> task_map = sam2.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': SAM2Predictor}
|
||||
|
||||
Notes:
|
||||
- Supports .pt and .pth file extensions for model weights.
|
||||
- Offers zero-shot transfer capabilities for new image distributions and tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, model="sam2_b.pt") -> None:
|
||||
"""
|
||||
Initializes the SAM2 model with a pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
|
||||
Examples:
|
||||
>>> sam2 = SAM2("sam2_b.pt")
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Loads the specified weights into the SAM2 model.
|
||||
|
||||
This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading
|
||||
weights from files with .pt or .pth extensions.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file. Should be a file with .pt or .pth extension.
|
||||
task (str | None): Task name. If provided, it may be used to configure model-specific settings.
|
||||
|
||||
Examples:
|
||||
>>> sam2_model = SAM2()
|
||||
>>> sam2_model._load('path/to/sam2_weights.pt')
|
||||
"""
|
||||
self.model = build_sam2(weights)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""
|
||||
Provides a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
(Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding
|
||||
SAM2Predictor class.
|
||||
|
||||
Examples:
|
||||
>>> sam2 = SAM2()
|
||||
>>> task_map = sam2.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': SAM2Predictor}
|
||||
"""
|
||||
return {"segment": {"predictor": SAM2Predictor}}
|
||||
|
|
@ -1 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
|
@ -1,305 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ultralytics.nn.modules import MLP, LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: Type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
use_high_res_features: bool = False,
|
||||
iou_prediction_use_sigmoid=False,
|
||||
dynamic_multimask_via_stability=False,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
pred_obj_scores: bool = False,
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the MaskDecoder module for predicting instance segmentation masks.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
|
||||
iou_head_depth (int): Depth of the MLP used to predict mask quality.
|
||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for IOU token.
|
||||
num_mask_tokens (int): Total number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for mask tokens.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
obj_score_token (nn.Embedding): Embedding for object score token.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
output_upscaling (nn.Sequential): Upscaling layers for output.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
|
||||
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
|
||||
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
|
||||
iou_prediction_head (MLP): MLP for IOU prediction.
|
||||
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
if self.pred_obj_scores:
|
||||
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.use_high_res_features = use_high_res_features
|
||||
if use_high_res_features:
|
||||
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
|
||||
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
|
||||
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
self.num_mask_tokens,
|
||||
iou_head_depth,
|
||||
sigmoid=iou_prediction_use_sigmoid,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
||||
if pred_obj_scores_mlp:
|
||||
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
||||
|
||||
# When outputting a single mask, optionally we can dynamically fall back to the best
|
||||
# multimask output token if the single mask output token gives low stability scores.
|
||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
||||
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
||||
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predicts masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder.
|
||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
|
||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
|
||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
repeat_image (bool): Flag to repeat the image embeddings.
|
||||
high_res_features (List[torch.Tensor] | None): Optional high-resolution features.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- masks (torch.Tensor): Batched predicted masks.
|
||||
- iou_pred (torch.Tensor): Batched predictions of mask quality.
|
||||
- sam_tokens_out (torch.Tensor): Batched SAM token for mask output.
|
||||
|
||||
Examples:
|
||||
>>> image_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
|
||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe,
|
||||
... sparse_prompt_embeddings, dense_prompt_embeddings, True, False)
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
repeat_image=repeat_image,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
masks = masks[:, 1:, :, :]
|
||||
iou_pred = iou_pred[:, 1:]
|
||||
elif self.dynamic_multimask_via_stability and not self.training:
|
||||
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
||||
else:
|
||||
masks = masks[:, 0:1, :, :]
|
||||
iou_pred = iou_pred[:, 0:1]
|
||||
|
||||
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
||||
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
||||
else:
|
||||
# Take the mask output token. Here we *always* use the token for single mask output.
|
||||
# At test time, even if we track after 1-click (and using multimask_output=True),
|
||||
# we still take the single mask token here. The rationale is that we always track
|
||||
# after multiple clicks during training, so the past tokens seen during training
|
||||
# are always the single mask token (and we'll let it be the object-memory token).
|
||||
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
||||
|
||||
# Prepare output
|
||||
return masks, iou_pred, sam_tokens_out, object_score_logits
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
repeat_image: bool,
|
||||
high_res_features: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture."""
|
||||
# Concatenate output tokens
|
||||
s = 0
|
||||
if self.pred_obj_scores:
|
||||
output_tokens = torch.cat(
|
||||
[
|
||||
self.obj_score_token.weight,
|
||||
self.iou_token.weight,
|
||||
self.mask_tokens.weight,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
s = 1
|
||||
else:
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
if repeat_image:
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
else:
|
||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, s, :]
|
||||
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
if not self.use_high_res_features:
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
else:
|
||||
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
||||
feat_s0, feat_s1 = high_res_features
|
||||
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
||||
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
||||
|
||||
hyper_in_list: List[torch.Tensor] = []
|
||||
for i in range(self.num_mask_tokens):
|
||||
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
if self.pred_obj_scores:
|
||||
assert s == 1
|
||||
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
||||
else:
|
||||
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
|
||||
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
||||
|
||||
return masks, iou_pred, mask_tokens_out, object_score_logits
|
||||
|
||||
def _get_stability_scores(self, mask_logits):
|
||||
"""Computes mask stability scores based on IoU between upper and lower thresholds."""
|
||||
mask_logits = mask_logits.flatten(-2)
|
||||
stability_delta = self.dynamic_multimask_stability_delta
|
||||
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
||||
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
||||
stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
|
||||
return stability_scores
|
||||
|
||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
||||
"""
|
||||
Dynamically selects the most stable mask output based on stability scores and IoU predictions.
|
||||
|
||||
When outputting a single mask, if the stability score from the current single-mask output (based on output token
|
||||
0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with
|
||||
the highest predicted IoU score.
|
||||
|
||||
This is intended to ensure a valid mask for both clicking and tracking.
|
||||
"""
|
||||
# The best mask from multimask output tokens (1~3)
|
||||
multimask_logits = all_mask_logits[:, 1:, :, :]
|
||||
multimask_iou_scores = all_iou_scores[:, 1:]
|
||||
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
||||
batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device)
|
||||
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
||||
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
||||
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
||||
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
||||
|
||||
# The mask from singlemask output token 0 and its stability score
|
||||
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
||||
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
||||
stability_scores = self._get_stability_scores(singlemask_logits)
|
||||
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
||||
|
||||
# Dynamically fall back to best multimask output upon low stability scores.
|
||||
mask_logits_out = torch.where(
|
||||
is_stable[..., None, None].expand_as(singlemask_logits),
|
||||
singlemask_logits,
|
||||
best_multimask_logits,
|
||||
)
|
||||
iou_scores_out = torch.where(
|
||||
is_stable.expand_as(singlemask_iou_scores),
|
||||
singlemask_iou_scores,
|
||||
best_multimask_iou_scores,
|
||||
)
|
||||
return mask_logits_out, iou_scores_out
|
||||
|
|
@ -1,332 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.models.sam.modules.encoders import PatchEmbed
|
||||
|
||||
from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
|
||||
|
||||
|
||||
class MemoryEncoder(nn.Module):
|
||||
"""Encodes pixel features and masks into a memory representation for efficient image segmentation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_dim,
|
||||
in_dim=256, # in_dim of pix_feats
|
||||
):
|
||||
"""Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
|
||||
super().__init__()
|
||||
|
||||
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
||||
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
|
||||
self.out_proj = nn.Identity()
|
||||
if out_dim != in_dim:
|
||||
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pix_feat: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
skip_mask_sigmoid: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Processes pixel features and masks, fusing them to generate encoded memory representations."""
|
||||
if not skip_mask_sigmoid:
|
||||
masks = F.sigmoid(masks)
|
||||
masks = self.mask_downsampler(masks)
|
||||
|
||||
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
|
||||
pix_feat = pix_feat.to(masks.device)
|
||||
|
||||
x = self.pix_feat_proj(pix_feat)
|
||||
x = x + masks
|
||||
x = self.fuser(x)
|
||||
x = self.out_proj(x)
|
||||
|
||||
pos = self.position_encoding(x).to(x.dtype)
|
||||
|
||||
return {"vision_features": x, "vision_pos_enc": [pos]}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
"""Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trunk: nn.Module,
|
||||
neck: nn.Module,
|
||||
scalp: int = 0,
|
||||
):
|
||||
"""Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
|
||||
super().__init__()
|
||||
self.trunk = trunk
|
||||
self.neck = neck
|
||||
self.scalp = scalp
|
||||
assert (
|
||||
self.trunk.channel_list == self.neck.backbone_channel_list
|
||||
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
||||
|
||||
def forward(self, sample: torch.Tensor):
|
||||
"""Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
|
||||
features, pos = self.neck(self.trunk(sample))
|
||||
if self.scalp > 0:
|
||||
# Discard the lowest resolution features
|
||||
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
||||
|
||||
src = features[-1]
|
||||
output = {
|
||||
"vision_features": src,
|
||||
"vision_pos_enc": pos,
|
||||
"backbone_fpn": features,
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
class FpnNeck(nn.Module):
|
||||
"""Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
backbone_channel_list: List[int],
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
fpn_interp_model: str = "bilinear",
|
||||
fuse_type: str = "sum",
|
||||
fpn_top_down_levels: Optional[List[int]] = None,
|
||||
):
|
||||
"""
|
||||
Initializes a modified Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
||||
similar to ViT positional embedding interpolation.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimension of the model.
|
||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
||||
kernel_size (int): Kernel size for the convolutional layers.
|
||||
stride (int): Stride for the convolutional layers.
|
||||
padding (int): Padding for the convolutional layers.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
||||
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
|
||||
|
||||
Attributes:
|
||||
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
|
||||
convs (nn.ModuleList): List of convolutional layers for each backbone level.
|
||||
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion.
|
||||
fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
|
||||
|
||||
Examples:
|
||||
>>> backbone_channels = [64, 128, 256, 512]
|
||||
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
||||
>>> print(fpn_neck)
|
||||
"""
|
||||
super().__init__()
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
|
||||
self.convs = nn.ModuleList()
|
||||
self.backbone_channel_list = backbone_channel_list
|
||||
for dim in backbone_channel_list:
|
||||
current = nn.Sequential()
|
||||
current.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=dim,
|
||||
out_channels=d_model,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
),
|
||||
)
|
||||
|
||||
self.convs.append(current)
|
||||
self.fpn_interp_model = fpn_interp_model
|
||||
assert fuse_type in ["sum", "avg"]
|
||||
self.fuse_type = fuse_type
|
||||
|
||||
# levels to have top-down features in its outputs
|
||||
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
||||
# have top-down propagation, while outputs of level 0 and level 1 have only
|
||||
# lateral features from the same backbone level.
|
||||
if fpn_top_down_levels is None:
|
||||
# default is to have top-down features on all levels
|
||||
fpn_top_down_levels = range(len(self.convs))
|
||||
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
||||
|
||||
def forward(self, xs: List[torch.Tensor]):
|
||||
"""
|
||||
Performs forward pass through the Feature Pyramid Network (FPN) neck.
|
||||
|
||||
Args:
|
||||
xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
|
||||
|
||||
Returns:
|
||||
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
|
||||
- out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
|
||||
- pos: List of positional encodings corresponding to each output feature map.
|
||||
|
||||
Examples:
|
||||
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
|
||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
|
||||
>>> outputs, positions = fpn_neck(inputs)
|
||||
"""
|
||||
out = [None] * len(self.convs)
|
||||
pos = [None] * len(self.convs)
|
||||
assert len(xs) == len(self.convs)
|
||||
# fpn forward pass
|
||||
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
||||
prev_features = None
|
||||
# forward in top-down order (from low to high resolution)
|
||||
n = len(self.convs) - 1
|
||||
for i in range(n, -1, -1):
|
||||
x = xs[i]
|
||||
lateral_features = self.convs[n - i](x)
|
||||
if i in self.fpn_top_down_levels and prev_features is not None:
|
||||
top_down_features = F.interpolate(
|
||||
prev_features.to(dtype=torch.float32),
|
||||
scale_factor=2.0,
|
||||
mode=self.fpn_interp_model,
|
||||
align_corners=(None if self.fpn_interp_model == "nearest" else False),
|
||||
antialias=False,
|
||||
)
|
||||
prev_features = lateral_features + top_down_features
|
||||
if self.fuse_type == "avg":
|
||||
prev_features /= 2
|
||||
else:
|
||||
prev_features = lateral_features
|
||||
x_out = prev_features
|
||||
out[i] = x_out
|
||||
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
class Hiera(nn.Module):
|
||||
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 96, # initial embed dim
|
||||
num_heads: int = 1, # initial number of heads
|
||||
drop_path_rate: float = 0.0, # stochastic depth
|
||||
q_pool: int = 3, # number of q_pool stages
|
||||
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
||||
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
||||
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
||||
head_mul: float = 2.0, # head_mul factor at stage shift
|
||||
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
||||
# window size per stage, when not using global att.
|
||||
window_spec: Tuple[int, ...] = (
|
||||
8,
|
||||
4,
|
||||
14,
|
||||
7,
|
||||
),
|
||||
# global attn in these blocks
|
||||
global_att_blocks: Tuple[int, ...] = (
|
||||
12,
|
||||
16,
|
||||
20,
|
||||
),
|
||||
return_interm_layers=True, # return feats from every stage
|
||||
):
|
||||
"""Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
|
||||
super().__init__()
|
||||
|
||||
assert len(stages) == len(window_spec)
|
||||
self.window_spec = window_spec
|
||||
|
||||
depth = sum(stages)
|
||||
self.q_stride = q_stride
|
||||
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
||||
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
||||
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
||||
self.return_interm_layers = return_interm_layers
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
embed_dim=embed_dim,
|
||||
kernel_size=(7, 7),
|
||||
stride=(4, 4),
|
||||
padding=(3, 3),
|
||||
)
|
||||
# Which blocks have global att?
|
||||
self.global_att_blocks = global_att_blocks
|
||||
|
||||
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
||||
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
|
||||
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
cur_stage = 1
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
dim_out = embed_dim
|
||||
# lags by a block, so first block of
|
||||
# next stage uses an initial window size
|
||||
# of previous stage and final window size of current stage
|
||||
window_size = self.window_spec[cur_stage - 1]
|
||||
|
||||
if self.global_att_blocks is not None:
|
||||
window_size = 0 if i in self.global_att_blocks else window_size
|
||||
|
||||
if i - 1 in self.stage_ends:
|
||||
dim_out = int(embed_dim * dim_mul)
|
||||
num_heads = int(num_heads * head_mul)
|
||||
cur_stage += 1
|
||||
|
||||
block = MultiScaleBlock(
|
||||
dim=embed_dim,
|
||||
dim_out=dim_out,
|
||||
num_heads=num_heads,
|
||||
drop_path=dpr[i],
|
||||
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
embed_dim = dim_out
|
||||
self.blocks.append(block)
|
||||
|
||||
self.channel_list = (
|
||||
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
||||
if return_interm_layers
|
||||
else [self.blocks[-1].dim_out]
|
||||
)
|
||||
|
||||
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional embeddings by interpolating and combining window and background embeddings."""
|
||||
h, w = hw
|
||||
window_embed = self.pos_embed_window
|
||||
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
||||
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
|
||||
x = self.patch_embed(x)
|
||||
# x: (B, H, W, C)
|
||||
|
||||
# Add pos embed
|
||||
x = x + self._get_pos_embed(x.shape[1:3])
|
||||
|
||||
outputs = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
|
||||
feats = x.permute(0, 3, 1, 2)
|
||||
outputs.append(feats)
|
||||
|
||||
return outputs
|
||||
|
|
@ -1,804 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from ultralytics.models.sam.modules.encoders import PromptEncoder
|
||||
from ultralytics.nn.modules import MLP
|
||||
|
||||
from .decoders import MaskDecoder
|
||||
from .sam2_blocks import TwoWayTransformer
|
||||
from .utils import get_1d_sine_pe, select_closest_cond_frames
|
||||
|
||||
# a large negative value as a placeholder score for missing objects
|
||||
NO_OBJ_SCORE = -1024.0
|
||||
|
||||
|
||||
class SAM2Model(torch.nn.Module):
|
||||
"""SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities."""
|
||||
|
||||
mask_threshold: float = 0.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_encoder,
|
||||
memory_attention,
|
||||
memory_encoder,
|
||||
num_maskmem=7, # default 1 input frame + 6 previous frames
|
||||
image_size=512,
|
||||
backbone_stride=16, # stride of the image backbone output
|
||||
sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob
|
||||
sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob
|
||||
# During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
|
||||
binarize_mask_from_pts_for_mem_enc=False,
|
||||
use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
|
||||
# The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
|
||||
# we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
|
||||
# a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
|
||||
max_cond_frames_in_attn=-1,
|
||||
# on the first frame, whether to directly add the no-memory embedding to the image feature
|
||||
# (instead of using the transformer encoder)
|
||||
directly_add_no_mem_embed=False,
|
||||
# whether to use high-resolution feature maps in the SAM mask decoder
|
||||
use_high_res_features_in_sam=False,
|
||||
# whether to output multiple (3) masks for the first click on initial conditioning frames
|
||||
multimask_output_in_sam=False,
|
||||
# the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
|
||||
# default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
|
||||
multimask_min_pt_num=1,
|
||||
multimask_max_pt_num=1,
|
||||
# whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
|
||||
multimask_output_for_tracking=False,
|
||||
# Whether to use multimask tokens for obj ptr; Only relevant when both
|
||||
# use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
# whether to use sigmoid to restrict ious prediction to [0-1]
|
||||
iou_prediction_use_sigmoid=False,
|
||||
# The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
|
||||
# For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
|
||||
# (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
|
||||
memory_temporal_stride_for_eval=1,
|
||||
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
||||
# if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames
|
||||
add_all_frames_to_correct_as_cond=False,
|
||||
# whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
|
||||
non_overlap_masks_for_mem_enc=False,
|
||||
# whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
|
||||
use_obj_ptrs_in_encoder=False,
|
||||
# the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
|
||||
max_obj_ptrs_in_encoder=16,
|
||||
# whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
# whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
|
||||
# with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
|
||||
proj_tpos_enc_in_obj_ptrs=False,
|
||||
# whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
|
||||
# (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
|
||||
only_obj_ptrs_in_the_past_for_eval=False,
|
||||
# Whether to predict if there is an object in the frame
|
||||
pred_obj_scores: bool = False,
|
||||
# Whether to use an MLP to predict object scores
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
# Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
|
||||
# Whether to have a fixed no obj pointer when there is no object present
|
||||
# or to use it as an additive embedding with obj_ptr produced by decoder
|
||||
fixed_no_obj_ptr: bool = False,
|
||||
# Soft no object, i.e. mix in no_obj_ptr softly,
|
||||
# hope to make recovery easier if there is a mistake and mitigate accumulation of errors
|
||||
soft_no_obj_ptr: bool = False,
|
||||
use_mlp_for_obj_ptr_proj: bool = False,
|
||||
# extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
|
||||
sam_mask_decoder_extra_args=None,
|
||||
compile_image_encoder: bool = False,
|
||||
):
|
||||
"""Initializes SAM2Model model with image encoder, memory attention, and memory encoder components."""
|
||||
super().__init__()
|
||||
|
||||
# Part 1: the image backbone
|
||||
self.image_encoder = image_encoder
|
||||
# Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
|
||||
self.use_high_res_features_in_sam = use_high_res_features_in_sam
|
||||
self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
|
||||
self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
|
||||
self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
|
||||
if use_obj_ptrs_in_encoder:
|
||||
# A conv layer to downsample the mask prompt to stride 4 (the same stride as
|
||||
# low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
|
||||
# so that it can be fed into the SAM mask decoder to generate a pointer.
|
||||
self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
|
||||
self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
|
||||
if proj_tpos_enc_in_obj_ptrs:
|
||||
assert add_tpos_enc_to_obj_ptrs # these options need to be used together
|
||||
self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
|
||||
self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
|
||||
|
||||
# Part 2: memory attention to condition current frame's visual features
|
||||
# with memories (and obj ptrs) from past frames
|
||||
self.memory_attention = memory_attention
|
||||
self.hidden_dim = memory_attention.d_model
|
||||
|
||||
# Part 3: memory encoder for the previous frame's outputs
|
||||
self.memory_encoder = memory_encoder
|
||||
self.mem_dim = self.hidden_dim
|
||||
if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"):
|
||||
# if there is compression of memories along channel dim
|
||||
self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
|
||||
self.num_maskmem = num_maskmem # Number of memories accessible
|
||||
# Temporal encoding of the memories
|
||||
self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim))
|
||||
trunc_normal_(self.maskmem_tpos_enc, std=0.02)
|
||||
# a single token to indicate no memory embedding from previous frames
|
||||
self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
||||
self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
|
||||
trunc_normal_(self.no_mem_embed, std=0.02)
|
||||
trunc_normal_(self.no_mem_pos_enc, std=0.02)
|
||||
self.directly_add_no_mem_embed = directly_add_no_mem_embed
|
||||
# Apply sigmoid to the output raw mask logits (to turn them from
|
||||
# range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
|
||||
self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
|
||||
self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
|
||||
self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
|
||||
self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
|
||||
self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
|
||||
# On frames with mask input, whether to directly output the input mask without
|
||||
# using a SAM prompt encoder + mask decoder
|
||||
self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
|
||||
self.multimask_output_in_sam = multimask_output_in_sam
|
||||
self.multimask_min_pt_num = multimask_min_pt_num
|
||||
self.multimask_max_pt_num = multimask_max_pt_num
|
||||
self.multimask_output_for_tracking = multimask_output_for_tracking
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
|
||||
|
||||
# Part 4: SAM-style prompt encoder (for both mask and point inputs)
|
||||
# and SAM-style mask decoder for the final mask output
|
||||
self.image_size = image_size
|
||||
self.backbone_stride = backbone_stride
|
||||
self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
self.pred_obj_scores_mlp = pred_obj_scores_mlp
|
||||
self.fixed_no_obj_ptr = fixed_no_obj_ptr
|
||||
self.soft_no_obj_ptr = soft_no_obj_ptr
|
||||
if self.fixed_no_obj_ptr:
|
||||
assert self.pred_obj_scores
|
||||
assert self.use_obj_ptrs_in_encoder
|
||||
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
|
||||
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
|
||||
trunc_normal_(self.no_obj_ptr, std=0.02)
|
||||
self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
|
||||
|
||||
self._build_sam_heads()
|
||||
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
||||
self.max_cond_frames_in_attn = max_cond_frames_in_attn
|
||||
|
||||
# Model compilation
|
||||
if compile_image_encoder:
|
||||
# Compile the forward function (not the full module) to allow loading checkpoints.
|
||||
print("Image encoder compilation is enabled. First forward pass will be slow.")
|
||||
self.image_encoder.forward = torch.compile(
|
||||
self.image_encoder.forward,
|
||||
mode="max-autotune",
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Returns the device on which the model's parameters are stored."""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
"""Processes input frames and prompts to generate object masks and scores in video sequences."""
|
||||
raise NotImplementedError(
|
||||
"Please use the corresponding methods in SAM2VideoPredictor for inference."
|
||||
"See notebooks/video_predictor_example.ipynb for an example."
|
||||
)
|
||||
|
||||
def _build_sam_heads(self):
|
||||
"""Builds SAM-style prompt encoder and mask decoder for image segmentation tasks."""
|
||||
self.sam_prompt_embed_dim = self.hidden_dim
|
||||
self.sam_image_embedding_size = self.image_size // self.backbone_stride
|
||||
|
||||
# build PromptEncoder and MaskDecoder from SAM
|
||||
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
|
||||
self.sam_prompt_encoder = PromptEncoder(
|
||||
embed_dim=self.sam_prompt_embed_dim,
|
||||
image_embedding_size=(
|
||||
self.sam_image_embedding_size,
|
||||
self.sam_image_embedding_size,
|
||||
),
|
||||
input_image_size=(self.image_size, self.image_size),
|
||||
mask_in_chans=16,
|
||||
)
|
||||
self.sam_mask_decoder = MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=self.sam_prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=self.sam_prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
use_high_res_features=self.use_high_res_features_in_sam,
|
||||
iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
|
||||
pred_obj_scores=self.pred_obj_scores,
|
||||
pred_obj_scores_mlp=self.pred_obj_scores_mlp,
|
||||
use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
|
||||
**(self.sam_mask_decoder_extra_args or {}),
|
||||
)
|
||||
if self.use_obj_ptrs_in_encoder:
|
||||
# a linear projection on SAM output tokens to turn them into object pointers
|
||||
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
|
||||
if self.use_mlp_for_obj_ptr_proj:
|
||||
self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3)
|
||||
else:
|
||||
self.obj_ptr_proj = torch.nn.Identity()
|
||||
if self.proj_tpos_enc_in_obj_ptrs:
|
||||
# a linear projection on temporal positional encoding in object pointers to
|
||||
# avoid potential interference with spatial positional encoding
|
||||
self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
|
||||
else:
|
||||
self.obj_ptr_tpos_proj = torch.nn.Identity()
|
||||
|
||||
def _forward_sam_heads(
|
||||
self,
|
||||
backbone_features,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
high_res_features=None,
|
||||
multimask_output=False,
|
||||
):
|
||||
"""
|
||||
Forward SAM prompt encoders and mask heads.
|
||||
|
||||
Args:
|
||||
backbone_features (torch.Tensor): Image features with shape (B, C, H, W).
|
||||
point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts.
|
||||
'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute
|
||||
pixel-unit coordinates in (x, y) format for P input points.
|
||||
'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks,
|
||||
0 means negative clicks, and -1 means padding.
|
||||
mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the
|
||||
same spatial size as the image.
|
||||
high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes
|
||||
(B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps
|
||||
for SAM decoder.
|
||||
multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False,
|
||||
output only 1 mask and its IoU estimate.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
|
||||
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
|
||||
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
|
||||
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
|
||||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
|
||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
|
||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
||||
object_score_logits: Tensor of shape (B,) with object score logits.
|
||||
|
||||
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
|
||||
|
||||
Examples:
|
||||
>>> backbone_features = torch.rand(1, 256, 32, 32)
|
||||
>>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])}
|
||||
>>> mask_inputs = torch.rand(1, 1, 512, 512)
|
||||
>>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs)
|
||||
>>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results
|
||||
"""
|
||||
B = backbone_features.size(0)
|
||||
device = backbone_features.device
|
||||
assert backbone_features.size(1) == self.sam_prompt_embed_dim
|
||||
assert backbone_features.size(2) == self.sam_image_embedding_size
|
||||
assert backbone_features.size(3) == self.sam_image_embedding_size
|
||||
|
||||
# a) Handle point prompts
|
||||
if point_inputs is not None:
|
||||
sam_point_coords = point_inputs["point_coords"]
|
||||
sam_point_labels = point_inputs["point_labels"]
|
||||
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
|
||||
else:
|
||||
# If no points are provide, pad with an empty point (with label -1)
|
||||
sam_point_coords = torch.zeros(B, 1, 2, device=device)
|
||||
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
|
||||
|
||||
# b) Handle mask prompts
|
||||
if mask_inputs is not None:
|
||||
# If mask_inputs is provided, downsize it into low-res mask input if needed
|
||||
# and feed it as a dense mask prompt into the SAM mask encoder
|
||||
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
|
||||
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
|
||||
sam_mask_prompt = F.interpolate(
|
||||
mask_inputs.float(),
|
||||
size=self.sam_prompt_encoder.mask_input_size,
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
antialias=True, # use antialias for downsampling
|
||||
)
|
||||
else:
|
||||
sam_mask_prompt = mask_inputs
|
||||
else:
|
||||
# Otherwise, simply feed None (and SAM's prompt encoder will add
|
||||
# a learned `no_mask_embed` to indicate no mask input in this case).
|
||||
sam_mask_prompt = None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
|
||||
points=(sam_point_coords, sam_point_labels),
|
||||
boxes=None,
|
||||
masks=sam_mask_prompt,
|
||||
)
|
||||
(
|
||||
low_res_multimasks,
|
||||
ious,
|
||||
sam_output_tokens,
|
||||
object_score_logits,
|
||||
) = self.sam_mask_decoder(
|
||||
image_embeddings=backbone_features,
|
||||
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
repeat_image=False, # the image is already batched
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
is_obj_appearing = object_score_logits > 0
|
||||
|
||||
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
||||
# consistent with the actual mask prediction
|
||||
low_res_multimasks = torch.where(
|
||||
is_obj_appearing[:, None, None],
|
||||
low_res_multimasks,
|
||||
NO_OBJ_SCORE,
|
||||
)
|
||||
|
||||
# convert masks from possibly bfloat16 (or float16) to float32
|
||||
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
||||
low_res_multimasks = low_res_multimasks.float()
|
||||
high_res_multimasks = F.interpolate(
|
||||
low_res_multimasks,
|
||||
size=(self.image_size, self.image_size),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
sam_output_token = sam_output_tokens[:, 0]
|
||||
if multimask_output:
|
||||
# take the best mask prediction (with the highest IoU estimation)
|
||||
best_iou_inds = torch.argmax(ious, dim=-1)
|
||||
batch_inds = torch.arange(B, device=device)
|
||||
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
|
||||
if sam_output_tokens.size(1) > 1:
|
||||
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
|
||||
else:
|
||||
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
|
||||
|
||||
# Extract object pointer from the SAM output token (with occlusion handling)
|
||||
obj_ptr = self.obj_ptr_proj(sam_output_token)
|
||||
if self.pred_obj_scores:
|
||||
# Allow *soft* no obj ptr, unlike for masks
|
||||
if self.soft_no_obj_ptr:
|
||||
# Only hard possible with gt
|
||||
assert not self.teacher_force_obj_scores_for_mem
|
||||
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
||||
else:
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
|
||||
if self.fixed_no_obj_ptr:
|
||||
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
||||
|
||||
return (
|
||||
low_res_multimasks,
|
||||
high_res_multimasks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
)
|
||||
|
||||
def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
|
||||
"""Processes mask inputs to generate output mask logits and object pointers without using SAM."""
|
||||
# Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
|
||||
out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05
|
||||
mask_inputs_float = mask_inputs.float()
|
||||
high_res_masks = mask_inputs_float * out_scale + out_bias
|
||||
low_res_masks = F.interpolate(
|
||||
high_res_masks,
|
||||
size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
|
||||
align_corners=False,
|
||||
mode="bilinear",
|
||||
antialias=True, # use antialias for downsampling
|
||||
)
|
||||
# a dummy IoU prediction of all 1's under mask input
|
||||
ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
|
||||
if not self.use_obj_ptrs_in_encoder:
|
||||
# all zeros as a dummy object pointer (of shape [B, C])
|
||||
obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device)
|
||||
else:
|
||||
# produce an object pointer using the SAM decoder from the mask input
|
||||
_, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
|
||||
backbone_features=backbone_features,
|
||||
mask_inputs=self.mask_downsample(mask_inputs_float),
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
# In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
|
||||
# Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
|
||||
# on the object_scores from the SAM decoder.
|
||||
is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
|
||||
is_obj_appearing = is_obj_appearing[..., None]
|
||||
lambda_is_obj_appearing = is_obj_appearing.float()
|
||||
object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
|
||||
if self.pred_obj_scores:
|
||||
if self.fixed_no_obj_ptr:
|
||||
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
||||
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
||||
|
||||
return (
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
ious,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
)
|
||||
|
||||
def forward_image(self, img_batch: torch.Tensor):
|
||||
"""Process image batch through encoder to extract multi-level features for SAM model."""
|
||||
backbone_out = self.image_encoder(img_batch)
|
||||
if self.use_high_res_features_in_sam:
|
||||
# precompute projected level 0 and level 1 features in SAM decoder
|
||||
# to avoid running it again on every SAM click
|
||||
backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
|
||||
backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
|
||||
return backbone_out
|
||||
|
||||
def _prepare_backbone_features(self, backbone_out):
|
||||
"""Prepare and flatten visual features from the image backbone output."""
|
||||
backbone_out = backbone_out.copy()
|
||||
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
||||
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
||||
|
||||
feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
|
||||
vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
|
||||
|
||||
feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
|
||||
vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
|
||||
|
||||
return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
|
||||
|
||||
def _prepare_memory_conditioned_features(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
||||
):
|
||||
"""Prepares memory-conditioned features by fusing current frame's visual features with previous memories."""
|
||||
B = current_vision_feats[-1].size(1) # batch size on this frame
|
||||
C = self.hidden_dim
|
||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
||||
device = current_vision_feats[-1].device
|
||||
# The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
|
||||
# In this case, we skip the fusion with any memory.
|
||||
if self.num_maskmem == 0: # Disable memory and skip fusion
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat
|
||||
|
||||
num_obj_ptr_tokens = 0
|
||||
# Step 1: condition the visual features of the current frame on previous memories
|
||||
if not is_init_cond_frame:
|
||||
# Retrieve the memories encoded with the maskmem backbone
|
||||
to_cat_memory, to_cat_memory_pos_embed = [], []
|
||||
# Add conditioning frames's output first (all cond frames have t_pos=0 for
|
||||
# when getting temporal positional embedding below)
|
||||
assert len(output_dict["cond_frame_outputs"]) > 0
|
||||
# Select a maximum number of temporally closest cond frames for cross attention
|
||||
cond_outputs = output_dict["cond_frame_outputs"]
|
||||
selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
|
||||
frame_idx, cond_outputs, self.max_cond_frames_in_attn
|
||||
)
|
||||
t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
|
||||
# Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory
|
||||
# the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
|
||||
# We also allow taking the memory frame non-consecutively (with r>1), in which case
|
||||
# we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
|
||||
r = self.memory_temporal_stride_for_eval
|
||||
for t_pos in range(1, self.num_maskmem):
|
||||
t_rel = self.num_maskmem - t_pos # how many frames before current frame
|
||||
if t_rel == 1:
|
||||
# for t_rel == 1, we take the last frame (regardless of r)
|
||||
if not track_in_reverse:
|
||||
# the frame immediately before this frame (i.e. frame_idx - 1)
|
||||
prev_frame_idx = frame_idx - t_rel
|
||||
else:
|
||||
# the frame immediately after this frame (i.e. frame_idx + 1)
|
||||
prev_frame_idx = frame_idx + t_rel
|
||||
else:
|
||||
# for t_rel >= 2, we take the memory frame from every r-th frames
|
||||
if not track_in_reverse:
|
||||
# first find the nearest frame among every r-th frames before this frame
|
||||
# for r=1, this would be (frame_idx - 2)
|
||||
prev_frame_idx = ((frame_idx - 2) // r) * r
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
|
||||
else:
|
||||
# first find the nearest frame among every r-th frames after this frame
|
||||
# for r=1, this would be (frame_idx + 2)
|
||||
prev_frame_idx = -(-(frame_idx + 2) // r) * r
|
||||
# then seek further among every r-th frames
|
||||
prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
|
||||
out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
|
||||
if out is None:
|
||||
# If an unselected conditioning frame is among the last (self.num_maskmem - 1)
|
||||
# frames, we still attend to it as if it's a non-conditioning frame.
|
||||
out = unselected_cond_outputs.get(prev_frame_idx, None)
|
||||
t_pos_and_prevs.append((t_pos, out))
|
||||
|
||||
for t_pos, prev in t_pos_and_prevs:
|
||||
if prev is None:
|
||||
continue # skip padding frames
|
||||
# "maskmem_features" might have been offloaded to CPU in demo use cases,
|
||||
# so we load it back to GPU (it's a no-op if it's already on GPU).
|
||||
feats = prev["maskmem_features"].cuda(non_blocking=True)
|
||||
to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
|
||||
# Spatial positional encoding (it might have been offloaded to CPU in eval)
|
||||
maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
|
||||
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
|
||||
# Temporal positional encoding
|
||||
maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
|
||||
to_cat_memory_pos_embed.append(maskmem_enc)
|
||||
|
||||
# Construct the list of past object pointers
|
||||
if self.use_obj_ptrs_in_encoder:
|
||||
max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
|
||||
# First add those object pointers from selected conditioning frames
|
||||
# (optionally, only include object pointers in the past during evaluation)
|
||||
if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
|
||||
ptr_cond_outputs = {
|
||||
t: out
|
||||
for t, out in selected_cond_outputs.items()
|
||||
if (t >= frame_idx if track_in_reverse else t <= frame_idx)
|
||||
}
|
||||
else:
|
||||
ptr_cond_outputs = selected_cond_outputs
|
||||
pos_and_ptrs = [
|
||||
# Temporal pos encoding contains how far away each pointer is from current frame
|
||||
(abs(frame_idx - t), out["obj_ptr"])
|
||||
for t, out in ptr_cond_outputs.items()
|
||||
]
|
||||
# Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
|
||||
for t_diff in range(1, max_obj_ptrs_in_encoder):
|
||||
t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
|
||||
if t < 0 or (num_frames is not None and t >= num_frames):
|
||||
break
|
||||
out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None))
|
||||
if out is not None:
|
||||
pos_and_ptrs.append((t_diff, out["obj_ptr"]))
|
||||
# If we have at least one object pointer, add them to the across attention
|
||||
if len(pos_and_ptrs) > 0:
|
||||
pos_list, ptrs_list = zip(*pos_and_ptrs)
|
||||
# stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
|
||||
obj_ptrs = torch.stack(ptrs_list, dim=0)
|
||||
# a temporal positional embedding based on how far each object pointer is from
|
||||
# the current frame (sine embedding normalized by the max pointer num).
|
||||
if self.add_tpos_enc_to_obj_ptrs:
|
||||
t_diff_max = max_obj_ptrs_in_encoder - 1
|
||||
tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
|
||||
obj_pos = torch.tensor(pos_list, device=device)
|
||||
obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
|
||||
obj_pos = self.obj_ptr_tpos_proj(obj_pos)
|
||||
obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
|
||||
else:
|
||||
obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
|
||||
if self.mem_dim < C:
|
||||
# split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
|
||||
obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim)
|
||||
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
|
||||
obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
|
||||
to_cat_memory.append(obj_ptrs)
|
||||
to_cat_memory_pos_embed.append(obj_pos)
|
||||
num_obj_ptr_tokens = obj_ptrs.shape[0]
|
||||
else:
|
||||
num_obj_ptr_tokens = 0
|
||||
else:
|
||||
# for initial conditioning frames, encode them without using any previous memory
|
||||
if self.directly_add_no_mem_embed:
|
||||
# directly add no-mem embedding (instead of using the transformer encoder)
|
||||
pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
# Use a dummy token on the first frame (to avoid empty memory input to transformer encoder)
|
||||
to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
|
||||
to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
|
||||
|
||||
# Step 2: Concatenate the memories and forward through the transformer encoder
|
||||
memory = torch.cat(to_cat_memory, dim=0)
|
||||
memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
|
||||
|
||||
pix_feat_with_mem = self.memory_attention(
|
||||
curr=current_vision_feats,
|
||||
curr_pos=current_vision_pos_embeds,
|
||||
memory=memory,
|
||||
memory_pos=memory_pos_embed,
|
||||
num_obj_ptr_tokens=num_obj_ptr_tokens,
|
||||
)
|
||||
# reshape the output (HW)BC => BCHW
|
||||
pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
|
||||
return pix_feat_with_mem
|
||||
|
||||
def _encode_new_memory(
|
||||
self,
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
pred_masks_high_res,
|
||||
is_mask_from_pts,
|
||||
):
|
||||
"""Encodes the current frame's features and predicted masks into a new memory representation."""
|
||||
B = current_vision_feats[-1].size(1) # batch size on this frame
|
||||
C = self.hidden_dim
|
||||
H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size
|
||||
# top-level feature, (HW)BC => BCHW
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
|
||||
if self.non_overlap_masks_for_mem_enc and not self.training:
|
||||
# optionally, apply non-overlapping constraints to the masks (it's applied
|
||||
# in the batch dimension and should only be used during eval, where all
|
||||
# the objects come from the same video under batch size 1).
|
||||
pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res)
|
||||
# scale the raw mask logits with a temperature before applying sigmoid
|
||||
binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
|
||||
if binarize and not self.training:
|
||||
mask_for_mem = (pred_masks_high_res > 0).float()
|
||||
else:
|
||||
# apply sigmoid on the raw mask logits to turn them into range (0, 1)
|
||||
mask_for_mem = torch.sigmoid(pred_masks_high_res)
|
||||
# apply scale and bias terms to the sigmoid probabilities
|
||||
if self.sigmoid_scale_for_mem_enc != 1.0:
|
||||
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
||||
if self.sigmoid_bias_for_mem_enc != 0.0:
|
||||
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
||||
maskmem_out = self.memory_encoder(
|
||||
pix_feat,
|
||||
mask_for_mem,
|
||||
skip_mask_sigmoid=True, # sigmoid already applied
|
||||
)
|
||||
maskmem_features = maskmem_out["vision_features"]
|
||||
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
||||
|
||||
return maskmem_features, maskmem_pos_enc
|
||||
|
||||
def track_step(
|
||||
self,
|
||||
frame_idx,
|
||||
is_init_cond_frame,
|
||||
current_vision_feats,
|
||||
current_vision_pos_embeds,
|
||||
feat_sizes,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
output_dict,
|
||||
num_frames,
|
||||
track_in_reverse=False, # tracking in reverse time order (for demo usage)
|
||||
# Whether to run the memory encoder on the predicted masks. Sometimes we might want
|
||||
# to skip the memory encoder with `run_mem_encoder=False`. For example,
|
||||
# in demo we might call `track_step` multiple times for each user click,
|
||||
# and only encode the memory when the user finalizes their clicks. And in ablation
|
||||
# settings like SAM training on static images, we don't need the memory encoder.
|
||||
run_mem_encoder=True,
|
||||
# The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
"""Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
|
||||
current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
|
||||
# High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
|
||||
if len(current_vision_feats) > 1:
|
||||
high_res_features = [
|
||||
x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
|
||||
for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
|
||||
]
|
||||
else:
|
||||
high_res_features = None
|
||||
if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
|
||||
# When use_mask_input_as_output_without_sam=True, we directly output the mask input
|
||||
# (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
|
||||
pix_feat = current_vision_feats[-1].permute(1, 2, 0)
|
||||
pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
|
||||
sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
|
||||
else:
|
||||
# fused the visual feature with previous memory features in the memory bank
|
||||
pix_feat_with_mem = self._prepare_memory_conditioned_features(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats[-1:],
|
||||
current_vision_pos_embeds=current_vision_pos_embeds[-1:],
|
||||
feat_sizes=feat_sizes[-1:],
|
||||
output_dict=output_dict,
|
||||
num_frames=num_frames,
|
||||
track_in_reverse=track_in_reverse,
|
||||
)
|
||||
# apply SAM-style segmentation head
|
||||
# here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
|
||||
# e.g. in demo where such logits come from earlier interaction instead of correction sampling
|
||||
# (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
|
||||
if prev_sam_mask_logits is not None:
|
||||
assert point_inputs is not None and mask_inputs is None
|
||||
mask_inputs = prev_sam_mask_logits
|
||||
multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
|
||||
sam_outputs = self._forward_sam_heads(
|
||||
backbone_features=pix_feat_with_mem,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
high_res_features=high_res_features,
|
||||
multimask_output=multimask_output,
|
||||
)
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
_,
|
||||
) = sam_outputs
|
||||
|
||||
current_out["pred_masks"] = low_res_masks
|
||||
current_out["pred_masks_high_res"] = high_res_masks
|
||||
current_out["obj_ptr"] = obj_ptr
|
||||
|
||||
# Finally run the memory encoder on the predicted mask to encode
|
||||
# it into a new memory feature (that can be used in future frames)
|
||||
if run_mem_encoder and self.num_maskmem > 0:
|
||||
high_res_masks_for_mem_enc = high_res_masks
|
||||
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks_for_mem_enc,
|
||||
is_mask_from_pts=(point_inputs is not None),
|
||||
)
|
||||
current_out["maskmem_features"] = maskmem_features
|
||||
current_out["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
current_out["maskmem_features"] = None
|
||||
current_out["maskmem_pos_enc"] = None
|
||||
|
||||
return current_out
|
||||
|
||||
def _use_multimask(self, is_init_cond_frame, point_inputs):
|
||||
"""Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs."""
|
||||
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
|
||||
multimask_output = (
|
||||
self.multimask_output_in_sam
|
||||
and (is_init_cond_frame or self.multimask_output_for_tracking)
|
||||
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
|
||||
)
|
||||
return multimask_output
|
||||
|
||||
def _apply_non_overlapping_constraints(self, pred_masks):
|
||||
"""Applies non-overlapping constraints to object masks, keeping highest scoring object at each location."""
|
||||
batch_size = pred_masks.size(0)
|
||||
if batch_size == 1:
|
||||
return pred_masks
|
||||
|
||||
device = pred_masks.device
|
||||
# "max_obj_inds": object index of the object with the highest score at each location
|
||||
max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
|
||||
# "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
|
||||
batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
|
||||
keep = max_obj_inds == batch_obj_inds
|
||||
# suppress overlapping regions' scores below -10.0 so that the foreground regions
|
||||
# don't overlap (here sigmoid(-10.0)=4.5398e-05)
|
||||
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
|
||||
return pred_masks
|
||||
|
|
@ -1,715 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import copy
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ultralytics.models.sam.modules.transformer import (
|
||||
Attention,
|
||||
)
|
||||
from ultralytics.models.sam.modules.transformer import (
|
||||
TwoWayAttentionBlock as SAMTwoWayAttentionBlock,
|
||||
)
|
||||
from ultralytics.models.sam.modules.transformer import (
|
||||
TwoWayTransformer as SAMTwoWayTransformer,
|
||||
)
|
||||
from ultralytics.nn.modules import MLP, LayerNorm2d
|
||||
|
||||
from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Implements stochastic depth regularization for neural networks during training."""
|
||||
|
||||
def __init__(self, drop_prob=0.0, scale_by_keep=True):
|
||||
"""Initialize DropPath module with specified drop probability and scaling option."""
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies stochastic depth to input tensor during training, with optional scaling."""
|
||||
if self.drop_prob == 0.0 or not self.training:
|
||||
return x
|
||||
keep_prob = 1 - self.drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and self.scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
class MaskDownSampler(nn.Module):
|
||||
"""Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim=256,
|
||||
kernel_size=4,
|
||||
stride=4,
|
||||
padding=0,
|
||||
total_stride=16,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""Initializes a mask downsampler module for progressive downsampling and channel expansion."""
|
||||
super().__init__()
|
||||
num_layers = int(math.log2(total_stride) // math.log2(stride))
|
||||
assert stride**num_layers == total_stride
|
||||
self.encoder = nn.Sequential()
|
||||
mask_in_chans, mask_out_chans = 1, 1
|
||||
for _ in range(num_layers):
|
||||
mask_out_chans = mask_in_chans * (stride**2)
|
||||
self.encoder.append(
|
||||
nn.Conv2d(
|
||||
mask_in_chans,
|
||||
mask_out_chans,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
)
|
||||
)
|
||||
self.encoder.append(LayerNorm2d(mask_out_chans))
|
||||
self.encoder.append(activation())
|
||||
mask_in_chans = mask_out_chans
|
||||
|
||||
self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
|
||||
|
||||
def forward(self, x):
|
||||
"""Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d."""
|
||||
return self.encoder(x)
|
||||
|
||||
|
||||
# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
|
||||
class CXBlock(nn.Module):
|
||||
"""
|
||||
ConvNeXt Block for efficient feature extraction in convolutional neural networks.
|
||||
|
||||
This block implements a modified version of the ConvNeXt architecture, offering two equivalent
|
||||
implementations for improved performance and flexibility.
|
||||
|
||||
Attributes:
|
||||
dwconv (nn.Conv2d): Depthwise convolution layer.
|
||||
norm (LayerNorm2d): Layer normalization applied to channels.
|
||||
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
|
||||
act (nn.GELU): GELU activation function.
|
||||
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
|
||||
gamma (nn.Parameter | None): Learnable scale parameter for layer scaling.
|
||||
drop_path (nn.Module): DropPath layer for stochastic depth regularization.
|
||||
|
||||
Methods:
|
||||
forward: Processes the input tensor through the ConvNeXt block.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> x = torch.randn(1, 64, 56, 56)
|
||||
>>> block = CXBlock(dim=64, kernel_size=7, padding=3)
|
||||
>>> output = block(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 64, 56, 56])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
kernel_size=7,
|
||||
padding=3,
|
||||
drop_path=0.0,
|
||||
layer_scale_init_value=1e-6,
|
||||
use_dwconv=True,
|
||||
):
|
||||
"""
|
||||
Initialize a ConvNeXt Block.
|
||||
|
||||
This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization,
|
||||
pointwise convolutions, and GELU activation.
|
||||
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
kernel_size (int): Size of the convolutional kernel. Default is 7.
|
||||
padding (int): Padding size for the convolution. Default is 3.
|
||||
drop_path (float): Stochastic depth rate. Default is 0.0.
|
||||
layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6.
|
||||
use_dwconv (bool): Whether to use depthwise convolution. Default is True.
|
||||
|
||||
Attributes:
|
||||
dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer.
|
||||
norm (LayerNorm2d): Layer normalization applied to the output of dwconv.
|
||||
pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer.
|
||||
act (nn.GELU): GELU activation function.
|
||||
pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer.
|
||||
gamma (nn.Parameter | None): Learnable scale parameter for the residual path.
|
||||
|
||||
Examples:
|
||||
>>> block = CXBlock(dim=64, kernel_size=7, padding=3)
|
||||
>>> x = torch.randn(1, 64, 32, 32)
|
||||
>>> output = block(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 64, 32, 32])
|
||||
"""
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv2d(
|
||||
dim,
|
||||
dim,
|
||||
kernel_size=kernel_size,
|
||||
padding=padding,
|
||||
groups=dim if use_dwconv else 1,
|
||||
) # depthwise conv
|
||||
self.norm = LayerNorm2d(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.pwconv2 = nn.Linear(4 * dim, dim)
|
||||
self.gamma = (
|
||||
nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
|
||||
if layer_scale_init_value > 0
|
||||
else None
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies ConvNeXt block operations to input tensor, including convolutions and residual connection."""
|
||||
input = x
|
||||
x = self.dwconv(x)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.pwconv2(x)
|
||||
if self.gamma is not None:
|
||||
x = self.gamma * x
|
||||
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
||||
|
||||
x = input + self.drop_path(x)
|
||||
return x
|
||||
|
||||
|
||||
class Fuser(nn.Module):
|
||||
"""
|
||||
A module for fusing features through multiple layers of a neural network.
|
||||
|
||||
This class applies a series of identical layers to an input tensor, optionally projecting the input first.
|
||||
|
||||
Attributes:
|
||||
proj (nn.Module): An optional input projection layer. Identity if no projection is needed.
|
||||
layers (nn.ModuleList): A list of identical layers to be applied sequentially.
|
||||
|
||||
Methods:
|
||||
forward: Applies the fuser to an input tensor.
|
||||
|
||||
Examples:
|
||||
>>> layer = CXBlock(dim=256)
|
||||
>>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True)
|
||||
>>> x = torch.randn(1, 256, 32, 32)
|
||||
>>> output = fuser(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 256, 32, 32])
|
||||
"""
|
||||
|
||||
def __init__(self, layer, num_layers, dim=None, input_projection=False):
|
||||
"""
|
||||
Initializes the Fuser module.
|
||||
|
||||
This module creates a sequence of identical layers and optionally applies an input projection.
|
||||
|
||||
Args:
|
||||
layer (nn.Module): The layer to be replicated in the fuser.
|
||||
num_layers (int): The number of times to replicate the layer.
|
||||
dim (int | None): The dimension for input projection, if used.
|
||||
input_projection (bool): Whether to use input projection.
|
||||
|
||||
Attributes:
|
||||
proj (nn.Module): The input projection layer, or nn.Identity if not used.
|
||||
layers (nn.ModuleList): A list of replicated layers.
|
||||
|
||||
Examples:
|
||||
>>> layer = nn.Linear(64, 64)
|
||||
>>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True)
|
||||
>>> input_tensor = torch.randn(1, 64)
|
||||
>>> output = fuser(input_tensor)
|
||||
"""
|
||||
super().__init__()
|
||||
self.proj = nn.Identity()
|
||||
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
|
||||
|
||||
if input_projection:
|
||||
assert dim is not None
|
||||
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
"""Applies a series of layers to the input tensor, optionally projecting it first."""
|
||||
x = self.proj(x)
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(SAMTwoWayAttentionBlock):
|
||||
"""
|
||||
A two-way attention block for performing self-attention and cross-attention in both directions.
|
||||
|
||||
This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on
|
||||
sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and
|
||||
cross-attention from dense to sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): Self-attention layer for queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization after the first attention block.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization after the second attention block.
|
||||
mlp (MLP): MLP block for transforming query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization after the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization after the third attention block.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer.
|
||||
|
||||
Methods:
|
||||
forward: Processes input through the attention blocks and MLP.
|
||||
|
||||
Examples:
|
||||
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
|
||||
>>> sparse_input = torch.randn(1, 100, 256)
|
||||
>>> dense_input = torch.randn(1, 256, 16, 16)
|
||||
>>> sparse_output, dense_output = block(sparse_input, dense_input)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions.
|
||||
|
||||
This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs
|
||||
to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): The channel dimension of the embeddings.
|
||||
num_heads (int): The number of heads in the attention layers.
|
||||
mlp_dim (int): The hidden dimension of the MLP block.
|
||||
activation (Type[nn.Module]): The activation function of the MLP block.
|
||||
attention_downsample_rate (int): The downsample rate for attention computations.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): The self-attention layer for the queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization following the first attention block.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization following the second attention block.
|
||||
mlp (MLP): MLP block that transforms the query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization following the MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization following the third attention block.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer.
|
||||
|
||||
Examples:
|
||||
>>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> sparse_inputs = torch.randn(1, 100, 256)
|
||||
>>> dense_inputs = torch.randn(1, 256, 32, 32)
|
||||
>>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs)
|
||||
"""
|
||||
super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe)
|
||||
self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation)
|
||||
|
||||
|
||||
class TwoWayTransformer(SAMTwoWayTransformer):
|
||||
"""
|
||||
A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer decoder that attends to an input image using queries with
|
||||
supplied positional embeddings. It is particularly useful for tasks like object detection, image
|
||||
segmentation, and point cloud processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
|
||||
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
||||
|
||||
Methods:
|
||||
forward: Processes input image embeddings and query embeddings through the transformer.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> image_embedding = torch.randn(1, 256, 64, 64)
|
||||
>>> query_embedding = torch.randn(1, 100, 256)
|
||||
>>> output = transformer(image_embedding, query_embedding)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: Type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a TwoWayTransformer instance.
|
||||
|
||||
This transformer decoder attends to an input image using queries with supplied positional embeddings.
|
||||
It is designed for tasks like object detection, image segmentation, and point cloud processing.
|
||||
|
||||
Args:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for the input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
|
||||
mlp_dim (int): Channel dimension internal to the MLP block.
|
||||
activation (Type[nn.Module]): Activation function to use in the MLP block.
|
||||
attention_downsample_rate (int): Downsampling rate for attention computations.
|
||||
|
||||
Attributes:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for the input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer.
|
||||
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> transformer
|
||||
TwoWayTransformer(
|
||||
(layers): ModuleList(
|
||||
(0-4): 5 x TwoWayAttentionBlock(...)
|
||||
)
|
||||
(final_attn_token_to_image): Attention(...)
|
||||
(norm_final_attn): LayerNorm(...)
|
||||
)
|
||||
"""
|
||||
super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate)
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RoPEAttention(Attention):
|
||||
"""Implements rotary position encoding for attention mechanisms in transformer architectures."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
rope_theta=10000.0,
|
||||
# whether to repeat q rope to match k length
|
||||
# this is needed for cross-attention to memories
|
||||
rope_k_repeat=False,
|
||||
feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
|
||||
**kwargs,
|
||||
):
|
||||
"""Initializes RoPEAttention with rotary position encoding for attention mechanisms."""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta)
|
||||
freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
|
||||
self.freqs_cis = freqs_cis
|
||||
self.rope_k_repeat = rope_k_repeat
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor:
|
||||
"""Applies rotary position encoding and computes attention between query, key, and value tensors."""
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Apply rotary position encoding
|
||||
w = h = math.sqrt(q.shape[-2])
|
||||
self.freqs_cis = self.freqs_cis.to(q.device)
|
||||
if self.freqs_cis.shape[0] != q.shape[-2]:
|
||||
self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
|
||||
if q.shape[-2] != k.shape[-2]:
|
||||
assert self.rope_k_repeat
|
||||
|
||||
num_k_rope = k.size(-2) - num_k_exclude_rope
|
||||
q, k[:, :, :num_k_rope] = apply_rotary_enc(
|
||||
q,
|
||||
k[:, :, :num_k_rope],
|
||||
freqs_cis=self.freqs_cis,
|
||||
repeat_freqs_k=self.rope_k_repeat,
|
||||
)
|
||||
|
||||
# Attention
|
||||
_, _, _, c_per_head = q.shape
|
||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ v
|
||||
|
||||
out = self._recombine_heads(out)
|
||||
out = self.out_proj(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
|
||||
"""Applies pooling and optional normalization to a tensor, handling permutations for spatial operations."""
|
||||
if pool is None:
|
||||
return x
|
||||
# (B, H, W, C) -> (B, C, H, W)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
x = pool(x)
|
||||
# (B, C, H', W') -> (B, H', W', C)
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
if norm:
|
||||
x = norm(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiScaleAttention(nn.Module):
|
||||
"""Implements multi-scale self-attention with optional query pooling for efficient feature extraction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: int,
|
||||
num_heads: int,
|
||||
q_pool: nn.Module = None,
|
||||
):
|
||||
"""Initializes a multi-scale attention module with configurable query pooling and linear projections."""
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim_out // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.q_pool = q_pool
|
||||
self.qkv = nn.Linear(dim, dim_out * 3)
|
||||
self.proj = nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies multi-scale attention to input tensor, optionally downsampling query features."""
|
||||
B, H, W, _ = x.shape
|
||||
# qkv with shape (B, H * W, 3, nHead, C)
|
||||
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
|
||||
# q, k, v with shape (B, H * W, nheads, C)
|
||||
q, k, v = torch.unbind(qkv, 2)
|
||||
|
||||
# Q pooling (for downsample at stage changes)
|
||||
if self.q_pool:
|
||||
q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
|
||||
H, W = q.shape[1:3] # downsampled shape
|
||||
q = q.reshape(B, H * W, self.num_heads, -1)
|
||||
|
||||
# Torch's SDPA expects [B, nheads, H*W, C] so we transpose
|
||||
x = F.scaled_dot_product_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
)
|
||||
# Transpose back
|
||||
x = x.transpose(1, 2)
|
||||
x = x.reshape(B, H, W, -1)
|
||||
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class MultiScaleBlock(nn.Module):
|
||||
"""Multiscale attention block with window partitioning and query pooling for efficient vision transformers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
drop_path: float = 0.0,
|
||||
norm_layer: Union[nn.Module, str] = "LayerNorm",
|
||||
q_stride: Tuple[int, int] = None,
|
||||
act_layer: nn.Module = nn.GELU,
|
||||
window_size: int = 0,
|
||||
):
|
||||
"""Initializes a multi-scale attention block with optional window partitioning and downsampling."""
|
||||
super().__init__()
|
||||
|
||||
if isinstance(norm_layer, str):
|
||||
norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
|
||||
|
||||
self.dim = dim
|
||||
self.dim_out = dim_out
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
||||
self.window_size = window_size
|
||||
|
||||
self.pool, self.q_stride = None, q_stride
|
||||
if self.q_stride:
|
||||
self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False)
|
||||
|
||||
self.attn = MultiScaleAttention(
|
||||
dim,
|
||||
dim_out,
|
||||
num_heads=num_heads,
|
||||
q_pool=self.pool,
|
||||
)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim_out)
|
||||
self.mlp = MLP(
|
||||
dim_out,
|
||||
int(dim_out * mlp_ratio),
|
||||
dim_out,
|
||||
num_layers=2,
|
||||
act=act_layer,
|
||||
)
|
||||
|
||||
if dim != dim_out:
|
||||
self.proj = nn.Linear(dim, dim_out)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Applies multi-scale attention and MLP processing to input tensor, with optional windowing."""
|
||||
shortcut = x # B, H, W, C
|
||||
x = self.norm1(x)
|
||||
|
||||
# Skip connection
|
||||
if self.dim != self.dim_out:
|
||||
shortcut = do_pool(self.proj(x), self.pool)
|
||||
|
||||
# Window partition
|
||||
window_size = self.window_size
|
||||
if window_size > 0:
|
||||
H, W = x.shape[1], x.shape[2]
|
||||
x, pad_hw = window_partition(x, window_size)
|
||||
|
||||
# Window Attention + Q Pooling (if stage change)
|
||||
x = self.attn(x)
|
||||
if self.q_stride:
|
||||
# Shapes have changed due to Q pooling
|
||||
window_size = self.window_size // self.q_stride[0]
|
||||
H, W = shortcut.shape[1:3]
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
pad_hw = (H + pad_h, W + pad_w)
|
||||
|
||||
# Reverse window partition
|
||||
if self.window_size > 0:
|
||||
x = window_unpartition(x, window_size, pad_hw, (H, W))
|
||||
|
||||
x = shortcut + self.drop_path(x)
|
||||
# MLP
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""Generates sinusoidal positional embeddings for 2D inputs like images."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_pos_feats,
|
||||
temperature: int = 10000,
|
||||
normalize: bool = True,
|
||||
scale: Optional[float] = None,
|
||||
):
|
||||
"""Initializes sinusoidal position embeddings for 2D image inputs."""
|
||||
super().__init__()
|
||||
assert num_pos_feats % 2 == 0, "Expecting even model width"
|
||||
self.num_pos_feats = num_pos_feats // 2
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
self.cache = {}
|
||||
|
||||
def _encode_xy(self, x, y):
|
||||
"""Encodes 2D positions using sine and cosine functions for positional embeddings."""
|
||||
assert len(x) == len(y) and x.ndim == y.ndim == 1
|
||||
x_embed = x * self.scale
|
||||
y_embed = y * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, None] / dim_t
|
||||
pos_y = y_embed[:, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
|
||||
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
|
||||
return pos_x, pos_y
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_boxes(self, x, y, w, h):
|
||||
"""Encodes box coordinates and dimensions into positional embeddings for object detection tasks."""
|
||||
pos_x, pos_y = self._encode_xy(x, y)
|
||||
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
|
||||
return pos
|
||||
|
||||
encode = encode_boxes # Backwards compatibility
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_points(self, x, y, labels):
|
||||
"""Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels."""
|
||||
(bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
|
||||
assert bx == by and nx == ny and bx == bl and nx == nl
|
||||
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
|
||||
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
|
||||
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
|
||||
return pos
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Generate sinusoidal position embeddings for 2D inputs."""
|
||||
cache_key = (x.shape[-2], x.shape[-1])
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
|
||||
y_embed = (
|
||||
torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
|
||||
.view(1, -1, 1)
|
||||
.repeat(x.shape[0], 1, x.shape[-1])
|
||||
)
|
||||
x_embed = (
|
||||
torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
|
||||
.view(1, 1, -1)
|
||||
.repeat(x.shape[0], x.shape[-2], 1)
|
||||
)
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
self.cache[cache_key] = pos[0]
|
||||
return pos
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import torch
|
||||
|
||||
from ..sam.predict import Predictor
|
||||
from .build import build_sam2
|
||||
|
||||
|
||||
class SAM2Predictor(Predictor):
|
||||
"""
|
||||
A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class.
|
||||
|
||||
This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's
|
||||
advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask
|
||||
generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks.
|
||||
|
||||
Attributes:
|
||||
cfg (Dict): Configuration dictionary specifying model and task-related parameters.
|
||||
overrides (Dict): Dictionary containing values that override the default configuration.
|
||||
_callbacks (Dict): Dictionary of user-defined callback functions to augment behavior.
|
||||
args (namespace): Namespace to hold command-line arguments or other operational variables.
|
||||
im (torch.Tensor): Preprocessed input image tensor.
|
||||
features (torch.Tensor): Extracted image features used for inference.
|
||||
prompts (Dict): Collection of various prompt types, such as bounding boxes and points.
|
||||
segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
|
||||
model (torch.nn.Module): The loaded SAM2 model.
|
||||
device (torch.device): The device (CPU or GPU) on which the model is loaded.
|
||||
_bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels.
|
||||
|
||||
Methods:
|
||||
get_model: Builds and returns the SAM2 model.
|
||||
prompt_inference: Performs image segmentation inference based on various prompts.
|
||||
set_image: Preprocesses and sets a single image for inference.
|
||||
get_im_features: Extracts image features from the SAM2 image encoder.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(model='sam2_l.pt')
|
||||
>>> predictor.set_image('path/to/image.jpg')
|
||||
>>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1])
|
||||
>>> print(f"Generated {len(masks)} mask(s) with scores: {scores}")
|
||||
"""
|
||||
|
||||
_bb_feat_sizes = [
|
||||
(256, 256),
|
||||
(128, 128),
|
||||
(64, 64),
|
||||
]
|
||||
|
||||
def get_model(self):
|
||||
"""Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks."""
|
||||
return build_sam2(self.args.model)
|
||||
|
||||
def prompt_inference(
|
||||
self,
|
||||
im,
|
||||
bboxes=None,
|
||||
points=None,
|
||||
labels=None,
|
||||
masks=None,
|
||||
multimask_output=False,
|
||||
img_idx=-1,
|
||||
):
|
||||
"""
|
||||
Performs image segmentation inference based on various prompts using SAM2 architecture.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
|
||||
points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
|
||||
multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
|
||||
img_idx (int): Index of the image in the batch to process.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(cfg)
|
||||
>>> image = torch.rand(1, 3, 640, 640)
|
||||
>>> bboxes = [[100, 100, 200, 200]]
|
||||
>>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
|
||||
"""
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
|
||||
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
|
||||
# Transform input prompts
|
||||
if points is not None:
|
||||
points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
|
||||
points = points[None] if points.ndim == 1 else points
|
||||
# Assuming labels are all positive if users don't pass labels.
|
||||
if labels is None:
|
||||
labels = torch.ones(points.shape[0])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
points *= r
|
||||
# (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
|
||||
points, labels = points[:, None], labels[:, None]
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bboxes = bboxes.view(-1, 2, 2) * r
|
||||
bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
|
||||
# NOTE: merge "boxes" and "points" into a single "points" input
|
||||
# (where boxes are added at the beginning) to model.sam_prompt_encoder
|
||||
if points is not None:
|
||||
points = torch.cat([bboxes, points], dim=1)
|
||||
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||
else:
|
||||
points, labels = bboxes, bbox_labels
|
||||
if masks is not None:
|
||||
masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
|
||||
|
||||
points = (points, labels) if points is not None else None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||
points=points,
|
||||
boxes=None,
|
||||
masks=masks,
|
||||
)
|
||||
# Predict masks
|
||||
batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction
|
||||
high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
|
||||
pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
|
||||
image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
|
||||
image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
dense_prompt_embeddings=dense_embeddings,
|
||||
multimask_output=multimask_output,
|
||||
repeat_image=batched_mode,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
# (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
|
||||
# `d` could be 1 or 3 depends on `multimask_output`.
|
||||
return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
Preprocesses and sets a single image for inference.
|
||||
|
||||
This function sets up the model if not already initialized, configures the data source to the specified image,
|
||||
and preprocesses the image for feature extraction. Only one image can be set at a time.
|
||||
|
||||
Args:
|
||||
image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2.
|
||||
|
||||
Raises:
|
||||
AssertionError: If more than one image is set.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor()
|
||||
>>> predictor.set_image("path/to/image.jpg")
|
||||
>>> predictor.set_image(np.array([...])) # Using a numpy array
|
||||
"""
|
||||
if self.model is None:
|
||||
self.setup_model(model=None)
|
||||
self.setup_source(image)
|
||||
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
|
||||
for batch in self.dataset:
|
||||
im = self.preprocess(batch[1])
|
||||
self.features = self.get_im_features(im)
|
||||
break
|
||||
|
||||
def get_im_features(self, im):
|
||||
"""Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks."""
|
||||
backbone_out = self.model.forward_image(im)
|
||||
_, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
|
||||
if self.model.directly_add_no_mem_embed:
|
||||
vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
|
||||
feats = [
|
||||
feat.permute(1, 2, 0).view(1, -1, *feat_size)
|
||||
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
||||
][::-1]
|
||||
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
||||
Loading…
Add table
Add a link
Reference in a new issue