ultralytics 8.3.38 SAM 2 video inference (#14851)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Laughing 2024-11-26 19:38:23 +08:00 committed by GitHub
parent 407815cf9e
commit dcc9bd536f
16 changed files with 917 additions and 124 deletions

@@ -36,8 +36,6 @@ class SAMModel(nn.Module):
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
Methods:
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
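
For context, the pixel_mean and pixel_std attributes documented above are the channel-wise normalization constants applied to input images before encoding. A minimal sketch of that preprocessing step, using the statistics from the SAM reference code (treat the exact values as an assumption):

import torch

# SAM-style channel statistics; shape (3, 1, 1) broadcasts over H and W
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(3, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(3, 1, 1)

def preprocess(x: torch.Tensor) -> torch.Tensor:
    """Normalize a (3, H, W) image tensor channel-wise."""
    return (x - pixel_mean) / pixel_std
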
@@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module):
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
# build PromptEncoder and MaskDecoder from SAM
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
# Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
self.sam_prompt_encoder = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
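
Note that sam_image_embedding_size is computed two lines up as image_size // backbone_stride. With the usual SAM 2 defaults (assumed here, not read from this diff) that yields a 64x64 prompt-embedding grid:

image_size = 1024      # assumed input resolution
backbone_stride = 16   # assumed total stride of the image encoder
sam_image_embedding_size = image_size // backbone_stride  # 1024 // 16 == 64
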
@@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module):
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
object_score_logits: Tensor of shape (B,) with object score logits.
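
The shapes in this docstring imply that the high-resolution masks are a 4x bilinear upsample of the decoder's low-resolution logits ((H*4, W*4) -> (H*16, W*16)). A minimal, self-contained sketch of that relationship (illustrative, not the exact ultralytics code):

import torch
import torch.nn.functional as F

B, M, H, W = 2, 3, 64, 64                             # batch, masks per prompt, feature grid
low_res_multimasks = torch.randn(B, M, H * 4, W * 4)  # (2, 3, 256, 256) SAM logits
high_res_multimasks = F.interpolate(
    low_res_multimasks, scale_factor=4, mode="bilinear", align_corners=False
)                                                     # (2, 3, 1024, 1024)
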
@@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module):
boxes=None,
masks=sam_mask_prompt,
)
(
low_res_multimasks,
ious,
sam_output_tokens,
object_score_logits,
) = self.sam_mask_decoder(
low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
@@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module):
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
# consistent with the actual mask prediction
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
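
Two details here are easy to miss: the per-image object-score boolean is unsqueezed so it broadcasts over the mask dimensions, hard-selecting between real logits and the no-object fill value, and the masks are cast to float32 because interpolate rejects bfloat16 on PyTorch < 2.1. A standalone sketch (NO_OBJ_SCORE value and the (B, 1) score shape taken from the SAM 2 reference code; treat both as assumptions):

import torch
import torch.nn.functional as F

NO_OBJ_SCORE = -1024.0                            # large negative fill for "no object present"
B, M = 2, 3
low_res_multimasks = torch.randn(B, M, 256, 256, dtype=torch.bfloat16)
object_score_logits = torch.randn(B, 1)           # per-image scores with a trailing singleton dim

is_obj_appearing = object_score_logits > 0        # (B, 1)
# (B, 1) -> (B, 1, 1, 1) so the flag broadcasts over every mask and pixel
low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)

# Cast to float32 before interpolate for pre-2.1 PyTorch compatibility
high_res_multimasks = F.interpolate(
    low_res_multimasks.float(), scale_factor=4, mode="bilinear", align_corners=False
)
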
@@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module):
def _prepare_backbone_features(self, backbone_out):
"""Prepares and flattens visual features from the image backbone output for further processing."""
backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
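
These asserts pin down the invariant the rest of the method relies on: one positional encoding per FPN level, and at least num_feature_levels levels to consume. The flattening itself falls outside this hunk; a minimal sketch of the conventional per-level reshape (layout assumed from the SAM 2 reference code):

import torch

B, C, H, W = 1, 256, 64, 64
feature_map = torch.randn(B, C, H, W)             # one FPN level from the image encoder
# (B, C, H, W) -> (H*W, B, C): turn the spatial grid into a token sequence
tokens = feature_map.flatten(2).permute(2, 0, 1)
assert tokens.shape == (H * W, B, C)
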
@@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module):
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
maskmem_out = self.memory_encoder(
pix_feat,
mask_for_mem,
skip_mask_sigmoid=True, # sigmoid already applied
)
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
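
By this point the predicted mask has already been squashed through a sigmoid (hence skip_mask_sigmoid=True), and the optional scale/bias stretch the [0, 1] mask into a wider logit-like range before it enters the memory encoder. A small sketch of that conditioning step (config values assumed for illustration):

import torch

sigmoid_scale_for_mem_enc = 20.0                  # assumed config value
sigmoid_bias_for_mem_enc = -10.0                  # assumed config value

pred_masks = torch.randn(1, 1, 1024, 1024)        # high-res mask logits for one object
mask_for_mem = torch.sigmoid(pred_masks)          # soft mask in [0, 1]
mask_for_mem = mask_for_mem * sigmoid_scale_for_mem_enc + sigmoid_bias_for_mem_enc
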
@@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module):
track_in_reverse,
prev_sam_mask_logits,
)
(
_,
_,
_,
low_res_masks,
high_res_masks,
obj_ptr,
object_score_logits,
) = sam_outputs
_, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
@@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module):
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
current_out["object_score_logits"] = object_score_logits
# Finally run the memory encoder on the predicted mask to encode
# it into a new memory feature (that can be used in future frames)
# Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
@@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module):
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
)
def _apply_non_overlapping_constraints(self, pred_masks):
"""Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
@staticmethod
def _apply_non_overlapping_constraints(pred_masks):
"""Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
@@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module):
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks
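
Putting the pieces of this method together: at every pixel, only the object with the highest logit keeps its score, and every other object is clamped to at most -10 so it binarizes to background (sigmoid(-10) is about 4.5e-5). A self-contained sketch mirroring the logic visible in this diff, where dim 0 indexes the tracked objects:

import torch

def apply_non_overlapping_constraints(pred_masks: torch.Tensor) -> torch.Tensor:
    """Keep only the highest-scoring object per pixel; pred_masks is (num_objects, 1, H, W) logits."""
    num_objects = pred_masks.size(0)
    if num_objects == 1:
        return pred_masks
    # Index of the winning object at each spatial location: (1, 1, H, W)
    max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
    # Each object's own index, shaped (num_objects, 1, 1, 1) for broadcasting
    batch_obj_inds = torch.arange(num_objects, device=pred_masks.device)[:, None, None, None]
    keep = max_obj_inds == batch_obj_inds
    # Losing objects are suppressed below -10 so their sigmoid is effectively zero
    return torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
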
def set_binarize(self, binarize=False):
"""Set binarize for VideoPredictor."""
self.binarize_mask_from_pts_for_mem_enc = binarize
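
set_binarize is the small hook the new video predictor uses to request hard (binarized) point-prompted masks when writing memories. A hedged end-to-end usage sketch of the video-inference path this PR introduces (class and override names follow the ultralytics docs for this release; treat them as assumptions):

from ultralytics.models.sam import SAM2VideoPredictor

# Configure the SAM 2 video predictor (overrides assumed from the docs)
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)

# Track the object under a single point prompt across all frames of the video
results = predictor(source="path/to/video.mp4", points=[920, 470], labels=[1])
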
def set_imgsz(self, imgsz):
"""
Set image size to make model compatible with different image sizes.