ultralytics 8.3.22 SAM2.1 integration (#17131)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Laughing 2024-10-25 08:17:46 +08:00 committed by GitHub
parent 55eec8347f
commit f6c378835b
6 changed files with 172 additions and 77 deletions

docs/en/models/sam-2.md

@@ -1,9 +1,13 @@
 ---
 comments: true
 description: Discover SAM 2, the next generation of Meta's Segment Anything Model, supporting real-time promptable segmentation in both images and videos with state-of-the-art performance. Learn about its key features, datasets, and how to use it.
-keywords: SAM 2, Segment Anything, video segmentation, image segmentation, promptable segmentation, zero-shot performance, SA-V dataset, Ultralytics, real-time segmentation, AI, machine learning
+keywords: SAM 2, SAM 2.1, Segment Anything, video segmentation, image segmentation, promptable segmentation, zero-shot performance, SA-V dataset, Ultralytics, real-time segmentation, AI, machine learning
 ---
 
+!!! tip "SAM 2.1"
+
+    We have just added support for the more accurate SAM 2.1 model. Please give it a try!
+
 # SAM 2: Segment Anything Model 2
 
 SAM 2, the successor to Meta's [Segment Anything Model (SAM)](sam.md), is a cutting-edge tool designed for comprehensive object segmentation in both images and videos. It excels in handling complex visual data through a unified, promptable model architecture that supports real-time processing and zero-shot generalization.
@@ -114,12 +118,16 @@ pip install ultralytics
 The following table details the available SAM 2 models, their pre-trained weights, supported tasks, and compatibility with different operating modes like [Inference](../modes/predict.md), [Validation](../modes/val.md), [Training](../modes/train.md), and [Export](../modes/export.md).
 
 | Model Type    | Pre-trained Weights                                                                        | Tasks Supported                              | Inference | Validation | Training | Export |
-| ----------- | ------------------------------------------------------------------------------------- | -------------------------------------------- | --------- | ---------- | -------- | ------ |
-| SAM 2 tiny  | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_t.pt)  | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
-| SAM 2 small | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_s.pt)  | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
-| SAM 2 base  | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_b.pt)  | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
-| SAM 2 large | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_l.pt)  | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| ------------- | ------------------------------------------------------------------------------------------ | -------------------------------------------- | --------- | ---------- | -------- | ------ |
+| SAM 2 tiny    | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_t.pt)       | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2 small   | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_s.pt)       | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2 base    | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_b.pt)       | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2 large   | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_l.pt)       | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2.1 tiny  | [sam2.1_t.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_t.pt)   | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2.1 small | [sam2.1_s.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_s.pt)   | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2.1 base  | [sam2.1_b.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_b.pt)   | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
+| SAM 2.1 large | [sam2.1_l.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_l.pt)   | [Instance Segmentation](../tasks/segment.md) | ✅        | ❌         | ❌       | ❌     |
 
 ### SAM 2 Prediction Examples
@@ -137,7 +145,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
         from ultralytics import SAM
 
         # Load a model
-        model = SAM("sam2_b.pt")
+        model = SAM("sam2.1_b.pt")
 
         # Display model information (optional)
         model.info()
@@ -170,7 +178,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
         from ultralytics import SAM
 
         # Load a model
-        model = SAM("sam2_b.pt")
+        model = SAM("sam2.1_b.pt")
 
         # Display model information (optional)
         model.info()
@@ -183,7 +191,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
         ```bash
         # Run inference with a SAM 2 model
-        yolo predict model=sam2_b.pt source=path/to/video.mp4
+        yolo predict model=sam2.1_b.pt source=path/to/video.mp4
         ```
 
 - This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts (bboxes/points/masks) are provided.
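A quick way to try the newly documented SAM 2.1 weights, following the prediction examples above (the image path and the prompt coordinates are placeholders):

```python
from ultralytics import SAM

# Load a SAM 2.1 checkpoint (fetched from the assets release listed in the table above if not present locally)
model = SAM("sam2.1_b.pt")
model.info()  # optional: display model information

# Prompted inference as in the docs: a box prompt in xyxy pixel coordinates, then a foreground point prompt
results = model("path/to/image.jpg", bboxes=[100, 100, 200, 200])
results = model("path/to/image.jpg", points=[150, 150], labels=[1])
```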

tests/test_python.py

@@ -116,7 +116,7 @@ def test_predict_sam():
     from ultralytics.models.sam import Predictor as SAMPredictor
 
     # Load a model
-    model = SAM(WEIGHTS_DIR / "sam_b.pt")
+    model = SAM(WEIGHTS_DIR / "sam2.1_b.pt")
 
     # Display model information (optional)
     model.info()
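To run just this updated test after the weight swap, one option is to invoke pytest directly; the `tests/test_python.py` path is assumed from the repository layout:

```python
# Equivalent to running `pytest tests/test_python.py -k test_predict_sam` from the repo root
import pytest

raise SystemExit(pytest.main(["tests/test_python.py", "-k", "test_predict_sam"]))
```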

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.21"
+__version__ = "8.3.22"
 
 import os

ultralytics/cfg/__init__.py

@@ -787,7 +787,7 @@ def entrypoint(debug=""):
             from ultralytics import FastSAM
 
             model = FastSAM(model)
-        elif "sam_" in stem or "sam2_" in stem:
+        elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem:
             from ultralytics import SAM
 
             model = SAM(model)
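The new condition is a plain substring check on the weight filename stem, which is what lets `sam2.1_*.pt` names reach the SAM loader ("sam2_" is not a substring of "sam2.1_b"). A standalone sketch of the same dispatch idea, with `guess_model_class` as a hypothetical stand-in for the selection that `entrypoint()` performs:

```python
from pathlib import Path


def guess_model_class(model: str) -> str:
    """Illustrative only: mirror the filename-stem dispatch for SAM-family weights."""
    stem = Path(model).stem
    if "fastsam" in stem:
        return "FastSAM"
    elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem:
        return "SAM"
    return "YOLO"  # assumption: anything else falls through to the default YOLO loader


print(guess_model_class("sam2_b.pt"))  # SAM (matches "sam2_")
print(guess_model_class("sam2.1_b.pt"))  # SAM (only matches the new "sam2.1_" check)
print(guess_model_class("yolo11n.pt"))  # YOLO
```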

ultralytics/models/sam/build.py

@@ -263,6 +263,7 @@ def _build_sam2(
     memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
     memory_encoder = MemoryEncoder(out_dim=64)
 
+    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
     sam2 = SAM2Model(
         image_encoder=image_encoder,
         memory_attention=memory_attention,
@@ -288,6 +289,9 @@ def _build_sam2(
         multimask_max_pt_num=1,
         use_mlp_for_obj_ptr_proj=True,
         compile_image_encoder=False,
+        no_obj_embed_spatial=is_sam2_1,
+        proj_tpos_enc_in_obj_ptrs=is_sam2_1,
+        use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
         sam_mask_decoder_extra_args=dict(
             dynamic_multimask_via_stability=True,
             dynamic_multimask_stability_delta=0.05,
@@ -313,6 +317,10 @@ sam_model_map = {
     "sam2_s.pt": build_sam2_s,
     "sam2_b.pt": build_sam2_b,
     "sam2_l.pt": build_sam2_l,
+    "sam2.1_t.pt": build_sam2_t,
+    "sam2.1_s.pt": build_sam2_s,
+    "sam2.1_b.pt": build_sam2_b,
+    "sam2.1_l.pt": build_sam2_l,
 }
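Because the flag is derived from the checkpoint filename, the four existing builders also serve the new `sam2.1_*.pt` entries, and `_build_sam2` simply switches the extra options on or off. A reduced sketch of that pattern (only the three keywords added above are shown; everything else is omitted):

```python
def sam21_flags(checkpoint):
    """Illustrative only: SAM 2.1 checkpoints enable the three extra SAM2Model options, SAM 2 checkpoints do not."""
    is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
    return {
        "no_obj_embed_spatial": is_sam2_1,
        "proj_tpos_enc_in_obj_ptrs": is_sam2_1,
        "use_signed_tpos_enc_to_obj_ptrs": is_sam2_1,
    }


print(sam21_flags("sam2.1_b.pt"))  # all True
print(sam21_flags("sam2_b.pt"))  # all False: SAM 2 weights keep the original behavior
```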

ultralytics/models/sam/modules/sam.py

@@ -161,18 +161,19 @@ class SAM2Model(torch.nn.Module):
         use_multimask_token_for_obj_ptr: bool = False,
         iou_prediction_use_sigmoid=False,
         memory_temporal_stride_for_eval=1,
-        add_all_frames_to_correct_as_cond=False,
         non_overlap_masks_for_mem_enc=False,
         use_obj_ptrs_in_encoder=False,
         max_obj_ptrs_in_encoder=16,
         add_tpos_enc_to_obj_ptrs=True,
         proj_tpos_enc_in_obj_ptrs=False,
+        use_signed_tpos_enc_to_obj_ptrs=False,
         only_obj_ptrs_in_the_past_for_eval=False,
         pred_obj_scores: bool = False,
         pred_obj_scores_mlp: bool = False,
         fixed_no_obj_ptr: bool = False,
         soft_no_obj_ptr: bool = False,
         use_mlp_for_obj_ptr_proj: bool = False,
+        no_obj_embed_spatial: bool = False,
         sam_mask_decoder_extra_args=None,
         compile_image_encoder: bool = False,
     ):
@@ -205,8 +206,6 @@ class SAM2Model(torch.nn.Module):
             use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers.
             iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1].
             memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation.
-            add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
-                frame list.
             non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
                 memory encoder during evaluation.
             use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder.
@@ -216,6 +215,9 @@ class SAM2Model(torch.nn.Module):
                 the encoder.
             proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
                 encoding in object pointers.
+            use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
+                in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True`
+                and `add_tpos_enc_to_obj_ptrs=True`.
             only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
                 during evaluation.
             pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@@ -223,6 +225,7 @@ class SAM2Model(torch.nn.Module):
             fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
             soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
             use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+            no_obj_embed_spatial (bool): Whether to add a no-object embedding to spatial frames.
             sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
             compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
@@ -253,6 +256,7 @@ class SAM2Model(torch.nn.Module):
         if proj_tpos_enc_in_obj_ptrs:
             assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
         self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
         self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
 
         # Part 2: memory attention to condition current frame's visual features
@@ -309,9 +313,12 @@ class SAM2Model(torch.nn.Module):
             self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
             trunc_normal_(self.no_obj_ptr, std=0.02)
         self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+        self.no_obj_embed_spatial = None
+        if no_obj_embed_spatial:
+            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+            trunc_normal_(self.no_obj_embed_spatial, std=0.02)
 
         self._build_sam_heads()
-        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
 
         # Model compilation
@@ -533,8 +540,6 @@ class SAM2Model(torch.nn.Module):
         if self.pred_obj_scores:
             # Allow *soft* no obj ptr, unlike for masks
             if self.soft_no_obj_ptr:
-                # Only hard possible with gt
-                assert not self.teacher_force_obj_scores_for_mem
                 lambda_is_obj_appearing = object_score_logits.sigmoid()
             else:
                 lambda_is_obj_appearing = is_obj_appearing.float()
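With the training-only assertion removed, the branch above reduces to a soft-versus-hard gate on the object score. A small sketch of the two cases, with made-up logits:

```python
import torch

# One object-score logit per object; > 0 means "an object is present"
object_score_logits = torch.tensor([[2.0], [-3.0]])
is_obj_appearing = object_score_logits > 0

soft = object_score_logits.sigmoid()  # soft_no_obj_ptr=True: smooth weight in (0, 1)
hard = is_obj_appearing.float()  # soft_no_obj_ptr=False: hard 0/1 weight

print(soft)  # tensor([[0.8808], [0.0474]])
print(hard)  # tensor([[1.], [0.]])
```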
@@ -647,6 +652,7 @@ class SAM2Model(torch.nn.Module):
         if self.num_maskmem == 0:  # Disable memory and skip fusion
             return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
         num_obj_ptr_tokens = 0
+        tpos_sign_mul = -1 if track_in_reverse else 1
         # Step 1: condition the visual features of the current frame on previous memories
         if not is_init_cond_frame:
             # Retrieve the memories encoded with the maskmem backbone
@@ -664,7 +670,7 @@ class SAM2Model(torch.nn.Module):
             # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
             # We also allow taking the memory frame non-consecutively (with r>1), in which case
             # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
-            r = self.memory_temporal_stride_for_eval
+            r = 1 if self.training else self.memory_temporal_stride_for_eval
             for t_pos in range(1, self.num_maskmem):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
@@ -718,7 +724,14 @@ class SAM2Model(torch.nn.Module):
                     ptr_cond_outputs = selected_cond_outputs
                 pos_and_ptrs = [
                     # Temporal pos encoding contains how far away each pointer is from current frame
-                    (abs(frame_idx - t), out["obj_ptr"])
+                    (
+                        (
+                            (frame_idx - t) * tpos_sign_mul
+                            if self.use_signed_tpos_enc_to_obj_ptrs
+                            else abs(frame_idx - t)
+                        ),
+                        out["obj_ptr"],
+                    )
                     for t, out in ptr_cond_outputs.items()
                 ]
                 # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
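The replacement expression is easiest to read with concrete numbers; a worked sketch of what it computes for a single object pointer (the helper function and frame indices below are made up for illustration):

```python
def tpos(frame_idx, t, use_signed, track_in_reverse=False):
    """Illustrative only: temporal position assigned to an object pointer taken from frame t, seen from frame_idx."""
    tpos_sign_mul = -1 if track_in_reverse else 1
    return (frame_idx - t) * tpos_sign_mul if use_signed else abs(frame_idx - t)


# Forward tracking at frame 10, pointers from an earlier (7) and a later (13) conditioning frame:
print(tpos(10, 7, use_signed=False), tpos(10, 13, use_signed=False))  # 3 3  (unsigned: direction lost)
print(tpos(10, 7, use_signed=True), tpos(10, 13, use_signed=True))  # 3 -3 (signed: direction kept)

# Reverse tracking flips the sign, so frames behind in tracking order stay positive:
print(tpos(10, 13, use_signed=True, track_in_reverse=True))  # 3
```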
@@ -787,6 +800,7 @@ class SAM2Model(torch.nn.Module):
         current_vision_feats,
         feat_sizes,
         pred_masks_high_res,
+        object_score_logits,
         is_mask_from_pts,
     ):
         """Encodes frame features and masks into a new memory representation for video segmentation."""
@@ -819,9 +833,102 @@ class SAM2Model(torch.nn.Module):
         )
         maskmem_features = maskmem_out["vision_features"]
         maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+        # add a no-object embedding to the spatial memory to indicate that the frame
+        # is predicted to be occluded (i.e. no object is appearing in the frame)
+        if self.no_obj_embed_spatial is not None:
+            is_obj_appearing = (object_score_logits > 0).float()
+            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+                ..., None, None
+            ].expand(*maskmem_features.shape)
 
         return maskmem_features, maskmem_pos_enc
 
+    def _track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse,
+        prev_sam_mask_logits,
+    ):
+        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+        if len(current_vision_feats) > 1:
+            high_res_features = [
+                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+            ]
+        else:
+            high_res_features = None
+        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+        else:
+            # fused the visual feature with previous memory features in the memory bank
+            pix_feat = self._prepare_memory_conditioned_features(
+                frame_idx=frame_idx,
+                is_init_cond_frame=is_init_cond_frame,
+                current_vision_feats=current_vision_feats[-1:],
+                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                feat_sizes=feat_sizes[-1:],
+                output_dict=output_dict,
+                num_frames=num_frames,
+                track_in_reverse=track_in_reverse,
+            )
+            # apply SAM-style segmentation head
+            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+            if prev_sam_mask_logits is not None:
+                assert point_inputs is not None and mask_inputs is None
+                mask_inputs = prev_sam_mask_logits
+            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat,
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=high_res_features,
+                multimask_output=multimask_output,
+            )
+        return current_out, sam_outputs, high_res_features, pix_feat
+
+    def _encode_memory_in_output(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        point_inputs,
+        run_mem_encoder,
+        high_res_masks,
+        object_score_logits,
+        current_out,
+    ):
+        """Finally run the memory encoder on the predicted mask to encode it into a new memory feature (that can be
+        used in future frames).
+        """
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                object_score_logits=object_score_logits,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
     def track_step(
         self,
         frame_idx,
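The no-object embedding added in `_encode_new_memory` above is blended into the spatial memory only for frames whose object score says the target is absent. A shape-level sketch of that broadcast, with made-up dimensions standing in for the real batch size and memory channel count:

```python
import torch

B, mem_dim, H, W = 2, 64, 8, 8
maskmem_features = torch.zeros(B, mem_dim, H, W)
no_obj_embed_spatial = torch.randn(1, mem_dim)

# One score per frame: frame 0 contains the object, frame 1 does not
object_score_logits = torch.tensor([[4.0], [-4.0]])
is_obj_appearing = (object_score_logits > 0).float()

# Same expression as in _encode_new_memory: absent frames get the no-object embedding at every spatial location
maskmem_features += (1 - is_obj_appearing[..., None, None]) * no_obj_embed_spatial[..., None, None].expand(
    *maskmem_features.shape
)

print(maskmem_features[0].abs().sum())  # tensor(0.) -> object present, memory untouched
print(torch.allclose(maskmem_features[1, :, 0, 0], no_obj_embed_spatial[0]))  # True -> embedding written everywhere
```

The new `_track_step` and `_encode_memory_in_output` helpers then carry `object_score_logits` from the SAM head down to this memory-encoding step, which is why `track_step` below shrinks to a pair of delegating calls.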
@@ -844,48 +951,20 @@ class SAM2Model(torch.nn.Module):
         prev_sam_mask_logits=None,
     ):
         """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
-        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
-        if len(current_vision_feats) > 1:
-            high_res_features = [
-                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
-                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
-            ]
-        else:
-            high_res_features = None
-        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
-            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
-            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
-            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
-            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
-        else:
-            # fused the visual feature with previous memory features in the memory bank
-            pix_feat_with_mem = self._prepare_memory_conditioned_features(
-                frame_idx=frame_idx,
-                is_init_cond_frame=is_init_cond_frame,
-                current_vision_feats=current_vision_feats[-1:],
-                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
-                feat_sizes=feat_sizes[-1:],
-                output_dict=output_dict,
-                num_frames=num_frames,
-                track_in_reverse=track_in_reverse,
-            )
-            # apply SAM-style segmentation head
-            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
-            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
-            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
-            if prev_sam_mask_logits is not None:
-                assert point_inputs is not None and mask_inputs is None
-                mask_inputs = prev_sam_mask_logits
-            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
-            sam_outputs = self._forward_sam_heads(
-                backbone_features=pix_feat_with_mem,
-                point_inputs=point_inputs,
-                mask_inputs=mask_inputs,
-                high_res_features=high_res_features,
-                multimask_output=multimask_output,
-            )
+        current_out, sam_outputs, _, _ = self._track_step(
+            frame_idx,
+            is_init_cond_frame,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+            point_inputs,
+            mask_inputs,
+            output_dict,
+            num_frames,
+            track_in_reverse,
+            prev_sam_mask_logits,
+        )
+
         (
             _,
             _,
@@ -893,28 +972,28 @@ class SAM2Model(torch.nn.Module):
             low_res_masks,
             high_res_masks,
             obj_ptr,
-            _,
+            object_score_logits,
         ) = sam_outputs
 
         current_out["pred_masks"] = low_res_masks
         current_out["pred_masks_high_res"] = high_res_masks
         current_out["obj_ptr"] = obj_ptr
+        if not self.training:
+            # Only add this in inference (to avoid unused param in activation checkpointing;
+            # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
+            current_out["object_score_logits"] = object_score_logits
 
         # Finally run the memory encoder on the predicted mask to encode
         # it into a new memory feature (that can be used in future frames)
-        if run_mem_encoder and self.num_maskmem > 0:
-            high_res_masks_for_mem_enc = high_res_masks
-            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
-                current_vision_feats=current_vision_feats,
-                feat_sizes=feat_sizes,
-                pred_masks_high_res=high_res_masks_for_mem_enc,
-                is_mask_from_pts=(point_inputs is not None),
-            )
-            current_out["maskmem_features"] = maskmem_features
-            current_out["maskmem_pos_enc"] = maskmem_pos_enc
-        else:
-            current_out["maskmem_features"] = None
-            current_out["maskmem_pos_enc"] = None
+        self._encode_memory_in_output(
+            current_vision_feats,
+            feat_sizes,
+            point_inputs,
+            run_mem_encoder,
+            high_res_masks,
+            object_score_logits,
+            current_out,
+        )
+
         return current_out