From f6c378835b0904b785cf19cdc492bdcfbd35bbf2 Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Fri, 25 Oct 2024 08:17:46 +0800
Subject: [PATCH] `ultralytics 8.3.22` SAM2.1 integration (#17131)

Co-authored-by: UltralyticsAssistant
Co-authored-by: Glenn Jocher
---
 docs/en/models/sam-2.md               |  28 ++--
 tests/test_cuda.py                    |   2 +-
 ultralytics/__init__.py               |   2 +-
 ultralytics/cfg/__init__.py           |   2 +-
 ultralytics/models/sam/build.py       |   8 +
 ultralytics/models/sam/modules/sam.py | 207 ++++++++++++++++++--------
 6 files changed, 172 insertions(+), 77 deletions(-)

diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md
index 5120498e..d5e8888e 100644
--- a/docs/en/models/sam-2.md
+++ b/docs/en/models/sam-2.md
@@ -1,9 +1,13 @@
 ---
 comments: true
 description: Discover SAM 2, the next generation of Meta's Segment Anything Model, supporting real-time promptable segmentation in both images and videos with state-of-the-art performance. Learn about its key features, datasets, and how to use it.
-keywords: SAM 2, Segment Anything, video segmentation, image segmentation, promptable segmentation, zero-shot performance, SA-V dataset, Ultralytics, real-time segmentation, AI, machine learning
+keywords: SAM 2, SAM 2.1, Segment Anything, video segmentation, image segmentation, promptable segmentation, zero-shot performance, SA-V dataset, Ultralytics, real-time segmentation, AI, machine learning
 ---
 
+!!! tip "SAM 2.1"
+
+    We now support the more accurate SAM 2.1 model. Please give it a try!
+
 # SAM 2: Segment Anything Model 2
 
 SAM 2, the successor to Meta's [Segment Anything Model (SAM)](sam.md), is a cutting-edge tool designed for comprehensive object segmentation in both images and videos. It excels in handling complex visual data through a unified, promptable model architecture that supports real-time processing and zero-shot generalization.
 
@@ -114,12 +118,16 @@ pip install ultralytics
 
 The following table details the available SAM 2 models, their pre-trained weights, supported tasks, and compatibility with different operating modes like [Inference](../modes/predict.md), [Validation](../modes/val.md), [Training](../modes/train.md), and [Export](../modes/export.md).
-| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | -| ----------- | ------------------------------------------------------------------------------------- | -------------------------------------------- | --------- | ---------- | -------- | ------ | -| SAM 2 tiny | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_t.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | -| SAM 2 small | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_s.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | -| SAM 2 base | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_b.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | -| SAM 2 large | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_l.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | +| ------------- | ----------------------------------------------------------------------------------------- | -------------------------------------------- | --------- | ---------- | -------- | ------ | +| SAM 2 tiny | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_t.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2 small | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_s.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2 base | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_b.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2 large | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2_l.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2.1 tiny | [sam2.1_t.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_t.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2.1 small | [sam2.1_s.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_s.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2.1 base | [sam2.1_b.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_b.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2.1 large | [sam2.1_l.pt](https://github.com/ultralytics/assets/releases/download/v8.3.0/sam2.1_l.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | ### SAM 2 Prediction Examples @@ -137,7 +145,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide from ultralytics import SAM # Load a model - model = SAM("sam2_b.pt") + model = SAM("sam2.1_b.pt") # Display model information (optional) model.info() @@ -170,7 +178,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide from ultralytics import SAM # Load a model - model = SAM("sam2_b.pt") + model = SAM("sam2.1_b.pt") # Display model information (optional) model.info() @@ -183,7 +191,7 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide ```bash # Run inference with a SAM 2 model - yolo predict model=sam2_b.pt source=path/to/video.mp4 + yolo predict model=sam2.1_b.pt source=path/to/video.mp4 ``` - This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts 
(bboxes/points/masks) are provided. diff --git a/tests/test_cuda.py b/tests/test_cuda.py index 89f8c39b..4fd1a7ae 100644 --- a/tests/test_cuda.py +++ b/tests/test_cuda.py @@ -116,7 +116,7 @@ def test_predict_sam(): from ultralytics.models.sam import Predictor as SAMPredictor # Load a model - model = SAM(WEIGHTS_DIR / "sam_b.pt") + model = SAM(WEIGHTS_DIR / "sam2.1_b.pt") # Display model information (optional) model.info() diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index ac22fe86..b9f92151 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.21" +__version__ = "8.3.22" import os diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 153ab27e..7c36a3b9 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -787,7 +787,7 @@ def entrypoint(debug=""): from ultralytics import FastSAM model = FastSAM(model) - elif "sam_" in stem or "sam2_" in stem: + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: from ultralytics import SAM model = SAM(model) diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index e1105312..cee5133a 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -263,6 +263,7 @@ def _build_sam2( memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) memory_encoder = MemoryEncoder(out_dim=64) + is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint sam2 = SAM2Model( image_encoder=image_encoder, memory_attention=memory_attention, @@ -288,6 +289,9 @@ def _build_sam2( multimask_max_pt_num=1, use_mlp_for_obj_ptr_proj=True, compile_image_encoder=False, + no_obj_embed_spatial=is_sam2_1, + proj_tpos_enc_in_obj_ptrs=is_sam2_1, + use_signed_tpos_enc_to_obj_ptrs=is_sam2_1, sam_mask_decoder_extra_args=dict( dynamic_multimask_via_stability=True, dynamic_multimask_stability_delta=0.05, @@ -313,6 +317,10 @@ sam_model_map = { "sam2_s.pt": build_sam2_s, "sam2_b.pt": build_sam2_b, "sam2_l.pt": build_sam2_l, + "sam2.1_t.pt": build_sam2_t, + "sam2.1_s.pt": build_sam2_s, + "sam2.1_b.pt": build_sam2_b, + "sam2.1_l.pt": build_sam2_l, } diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 2728b0b4..562314b2 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -161,18 +161,19 @@ class SAM2Model(torch.nn.Module): use_multimask_token_for_obj_ptr: bool = False, iou_prediction_use_sigmoid=False, memory_temporal_stride_for_eval=1, - add_all_frames_to_correct_as_cond=False, non_overlap_masks_for_mem_enc=False, use_obj_ptrs_in_encoder=False, max_obj_ptrs_in_encoder=16, add_tpos_enc_to_obj_ptrs=True, proj_tpos_enc_in_obj_ptrs=False, + use_signed_tpos_enc_to_obj_ptrs=False, only_obj_ptrs_in_the_past_for_eval=False, pred_obj_scores: bool = False, pred_obj_scores_mlp: bool = False, fixed_no_obj_ptr: bool = False, soft_no_obj_ptr: bool = False, use_mlp_for_obj_ptr_proj: bool = False, + no_obj_embed_spatial: bool = False, sam_mask_decoder_extra_args=None, compile_image_encoder: bool = False, ): @@ -205,8 +206,6 @@ class SAM2Model(torch.nn.Module): use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. 
-        add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning
-            frame list.
         non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in
             memory encoder during evaluation.
         use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in
@@ -216,6 +215,9 @@
             the encoder.
         proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional
             encoding in object pointers.
+        use_signed_tpos_enc_to_obj_ptrs (bool): Whether to use signed distance (instead of unsigned absolute distance)
+            in the temporal positional encoding of object pointers; only relevant when both `use_obj_ptrs_in_encoder=True`
+            and `add_tpos_enc_to_obj_ptrs=True`.
         only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past
             during evaluation.
         pred_obj_scores (bool): Whether to predict if there is an object in the frame.
@@ -223,6 +225,7 @@
         fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present.
         soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation.
         use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection.
+        no_obj_embed_spatial (bool): Whether to add a no-object embedding to spatial frames.
         sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder.
         compile_image_encoder (bool): Whether to compile the image encoder for faster inference.
 
@@ -253,6 +256,7 @@
         if proj_tpos_enc_in_obj_ptrs:
             assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
         self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+        self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs
         self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
 
         # Part 2: memory attention to condition current frame's visual features
@@ -309,9 +313,12 @@
             self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
             trunc_normal_(self.no_obj_ptr, std=0.02)
         self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+        self.no_obj_embed_spatial = None
+        if no_obj_embed_spatial:
+            self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim))
+            trunc_normal_(self.no_obj_embed_spatial, std=0.02)
 
         self._build_sam_heads()
-        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
         self.max_cond_frames_in_attn = max_cond_frames_in_attn
 
         # Model compilation
@@ -533,8 +540,6 @@
         if self.pred_obj_scores:
             # Allow *soft* no obj ptr, unlike for masks
             if self.soft_no_obj_ptr:
-                # Only hard possible with gt
-                assert not self.teacher_force_obj_scores_for_mem
                 lambda_is_obj_appearing = object_score_logits.sigmoid()
             else:
                 lambda_is_obj_appearing = is_obj_appearing.float()
@@ -647,6 +652,7 @@
         if self.num_maskmem == 0:  # Disable memory and skip fusion
             return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
         num_obj_ptr_tokens = 0
+        tpos_sign_mul = -1 if track_in_reverse else 1
         # Step 1: condition the visual features of the current frame on previous memories
         if not is_init_cond_frame:
             # Retrieve the memories encoded with the maskmem backbone
@@ -664,7 +670,7 @@
             # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1
             # We also allow taking the memory frame non-consecutively (with r>1), in which case
             # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
-            r = self.memory_temporal_stride_for_eval
+            r = 1 if self.training else self.memory_temporal_stride_for_eval
             for t_pos in range(1, self.num_maskmem):
                 t_rel = self.num_maskmem - t_pos  # how many frames before current frame
                 if t_rel == 1:
@@ -718,7 +724,14 @@
                     ptr_cond_outputs = selected_cond_outputs
                 pos_and_ptrs = [
                     # Temporal pos encoding contains how far away each pointer is from current frame
-                    (abs(frame_idx - t), out["obj_ptr"])
+                    (
+                        (
+                            (frame_idx - t) * tpos_sign_mul
+                            if self.use_signed_tpos_enc_to_obj_ptrs
+                            else abs(frame_idx - t)
+                        ),
+                        out["obj_ptr"],
+                    )
                     for t, out in ptr_cond_outputs.items()
                 ]
                 # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
@@ -787,6 +800,7 @@
         current_vision_feats,
         feat_sizes,
         pred_masks_high_res,
+        object_score_logits,
         is_mask_from_pts,
     ):
         """Encodes frame features and masks into a new memory representation for video segmentation."""
@@ -819,9 +833,102 @@
         )
         maskmem_features = maskmem_out["vision_features"]
         maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+        # add a no-object embedding to the spatial memory to indicate that the frame
+        # is predicted to be occluded (i.e. no object is appearing in the frame)
+        if self.no_obj_embed_spatial is not None:
+            is_obj_appearing = (object_score_logits > 0).float()
+            maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[
+                ..., None, None
+            ].expand(*maskmem_features.shape)
 
         return maskmem_features, maskmem_pos_enc
 
+    def _track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse,
+        prev_sam_mask_logits,
+    ):
+        """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
+        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+        if len(current_vision_feats) > 1:
+            high_res_features = [
+                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+            ]
+        else:
+            high_res_features = None
+        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
+        else:
+            # fuse the visual feature with previous memory features in the memory bank
+            pix_feat = self._prepare_memory_conditioned_features(
+                frame_idx=frame_idx,
+                is_init_cond_frame=is_init_cond_frame,
+                current_vision_feats=current_vision_feats[-1:],
+                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                feat_sizes=feat_sizes[-1:],
+                output_dict=output_dict,
+                num_frames=num_frames,
+                track_in_reverse=track_in_reverse,
+            )
+            # apply SAM-style segmentation head
+            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+            if prev_sam_mask_logits is not None:
+                assert point_inputs is not None and mask_inputs is None
+                mask_inputs = prev_sam_mask_logits
+            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat,
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=high_res_features,
+                multimask_output=multimask_output,
+            )
+        return current_out, sam_outputs, high_res_features, pix_feat
+
+    def _encode_memory_in_output(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        point_inputs,
+        run_mem_encoder,
+        high_res_masks,
+        object_score_logits,
+        current_out,
+    ):
+        """Finally run the memory encoder on the predicted mask to encode it into a new memory feature (that can be
+        used in future frames).
+        """
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                object_score_logits=object_score_logits,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
     def track_step(
         self,
         frame_idx,
@@ -844,48 +951,20 @@
         prev_sam_mask_logits=None,
     ):
         """Performs a single tracking step, updating object masks and memory features based on current frame inputs."""
-        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
-        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
-        if len(current_vision_feats) > 1:
-            high_res_features = [
-                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
-                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
-            ]
-        else:
-            high_res_features = None
-        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
-            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
-            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
-            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
-            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
-            sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs)
-        else:
-            # fused the visual feature with previous memory features in the memory bank
-            pix_feat_with_mem = self._prepare_memory_conditioned_features(
-                frame_idx=frame_idx,
-                is_init_cond_frame=is_init_cond_frame,
-                current_vision_feats=current_vision_feats[-1:],
-                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
-                feat_sizes=feat_sizes[-1:],
-                output_dict=output_dict,
-                num_frames=num_frames,
-                track_in_reverse=track_in_reverse,
-            )
-            # apply SAM-style segmentation head
-            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
-            # e.g.
in demo where such logits come from earlier interaction instead of correction sampling - # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) - if prev_sam_mask_logits is not None: - assert point_inputs is not None and mask_inputs is None - mask_inputs = prev_sam_mask_logits - multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) - sam_outputs = self._forward_sam_heads( - backbone_features=pix_feat_with_mem, - point_inputs=point_inputs, - mask_inputs=mask_inputs, - high_res_features=high_res_features, - multimask_output=multimask_output, - ) + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + ( _, _, @@ -893,28 +972,28 @@ class SAM2Model(torch.nn.Module): low_res_masks, high_res_masks, obj_ptr, - _, + object_score_logits, ) = sam_outputs current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) - if run_mem_encoder and self.num_maskmem > 0: - high_res_masks_for_mem_enc = high_res_masks - maskmem_features, maskmem_pos_enc = self._encode_new_memory( - current_vision_feats=current_vision_feats, - feat_sizes=feat_sizes, - pred_masks_high_res=high_res_masks_for_mem_enc, - is_mask_from_pts=(point_inputs is not None), - ) - current_out["maskmem_features"] = maskmem_features - current_out["maskmem_pos_enc"] = maskmem_pos_enc - else: - current_out["maskmem_features"] = None - current_out["maskmem_pos_enc"] = None + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) return current_out
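Beyond the new checkpoint table, the SAM 2.1-specific behavior wired through `_build_sam2` comes down to two flags: `use_signed_tpos_enc_to_obj_ptrs`, which keeps the sign of the temporal offset used in the object-pointer position encoding (negated when tracking in reverse), and `no_obj_embed_spatial`, which blends a learned embedding into the spatial memory of frames where no object is predicted. The sketch below restates both computations in isolation; the tensor shapes, function names, and the `__main__` example are illustrative assumptions, not part of the Ultralytics API.

```python
import torch


def tpos_offset(frame_idx: int, t: int, signed: bool, track_in_reverse: bool) -> int:
    """Temporal position of an object pointer relative to the current frame.

    Signed offsets (SAM 2.1) distinguish frames behind and ahead of the current frame;
    unsigned offsets (SAM 2) only encode the absolute distance.
    """
    tpos_sign_mul = -1 if track_in_reverse else 1
    return (frame_idx - t) * tpos_sign_mul if signed else abs(frame_idx - t)


def blend_no_obj_embedding(
    maskmem_features: torch.Tensor,  # (B, C, H, W) memory features for the current frame
    object_score_logits: torch.Tensor,  # (B, 1) logits; > 0 means an object is predicted present
    no_obj_embed_spatial: torch.Tensor,  # (1, C) learned parameter, broadcast over H and W
) -> torch.Tensor:
    """Mix a learned 'no object' embedding into the spatial memory of occluded frames."""
    is_obj_appearing = (object_score_logits > 0).float()  # (B, 1)
    weight = 1 - is_obj_appearing[..., None, None]  # (B, 1, 1, 1): 1 where no object is predicted
    return maskmem_features + weight * no_obj_embed_spatial[..., None, None].expand_as(maskmem_features)


if __name__ == "__main__":
    feats = torch.randn(2, 64, 32, 32)
    logits = torch.tensor([[5.0], [-3.0]])  # frame 0 has an object, frame 1 does not
    embed = torch.zeros(1, 64)  # stands in for the trainable no_obj_embed_spatial parameter
    print(blend_no_obj_embedding(feats, logits, embed).shape)
    print(tpos_offset(frame_idx=10, t=7, signed=True, track_in_reverse=False))  # -> 3
```

With `signed=True` and `track_in_reverse=True` the same call returns -3, which is what lets the SAM 2.1 positional encoding tell forward tracking apart from backward tracking.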