ultralytics 8.3.38 SAM 2 video inference (#14851)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent 407815cf9e
commit dcc9bd536f
16 changed files with 917 additions and 124 deletions
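The headline feature of this commit is streaming video inference with SAM 2. A minimal usage sketch of the new predictor, following the pattern documented for it; the weights file, point coordinates, and confidence value are illustrative:

from ultralytics.models.sam import SAM2VideoPredictor

# Create the SAM 2 video predictor; conf/imgsz/model values are illustrative
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)

# Track the object at pixel (920, 470) through every frame of the video
results = predictor(source="test.mp4", points=[920, 470], labels=[1])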
@@ -36,8 +36,6 @@ class SAMModel(nn.Module):
         image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
         prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
         mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
-        pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
-        pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).

     Methods:
         __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
@@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module):
         self.sam_prompt_embed_dim = self.hidden_dim
         self.sam_image_embedding_size = self.image_size // self.backbone_stride

-        # build PromptEncoder and MaskDecoder from SAM
-        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+        # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
         self.sam_prompt_encoder = PromptEncoder(
             embed_dim=self.sam_prompt_embed_dim,
             image_embedding_size=(
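The hunk above abbreviates the constructor call. A standalone sketch of building the same PromptEncoder; the import path and the SAM-style sizes (hidden_dim=256, image_size=1024, backbone_stride=16) are assumptions:

from ultralytics.models.sam.modules.encoders import PromptEncoder  # assumed import path

embed_dim = 256                                        # self.hidden_dim
image_size, backbone_stride = 1024, 16
image_embedding_size = image_size // backbone_stride   # 64

prompt_encoder = PromptEncoder(
    embed_dim=embed_dim,
    image_embedding_size=(image_embedding_size, image_embedding_size),
    input_image_size=(image_size, image_size),
    mask_in_chans=16,  # the SAM hyperparameter mentioned in the comment above
)
sparse, dense = prompt_encoder(points=None, boxes=None, masks=None)  # unprompted embeddings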
@@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module):
            low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
            high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
            ious: Tensor of shape (B, M) with estimated IoU for each output mask.
-           low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
-           high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
+           low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+           high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
            obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
            object_score_logits: Tensor of shape (B,) with object score logits.
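For clarity on the (B, M, ...) to (B, 1, ...) reduction described above, a small illustrative sketch of selecting the single best mask by predicted IoU; the sizes are hypothetical:

import torch

B, M, H4, W4 = 2, 3, 256, 256                 # hypothetical batch, mask count, low-res size
low_res_multimasks = torch.randn(B, M, H4, W4)
ious = torch.rand(B, M)

best_iou_inds = torch.argmax(ious, dim=-1)    # (B,) index of the highest-IoU mask per image
batch_inds = torch.arange(B)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
print(low_res_masks.shape)                    # torch.Size([2, 1, 256, 256])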
@@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module):
             boxes=None,
             masks=sam_mask_prompt,
         )
-        (
-            low_res_multimasks,
-            ious,
-            sam_output_tokens,
-            object_score_logits,
-        ) = self.sam_mask_decoder(
+        low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
             image_embeddings=backbone_features,
             image_pe=self.sam_prompt_encoder.get_dense_pe(),
             sparse_prompt_embeddings=sparse_embeddings,
@@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module):
         if self.pred_obj_scores:
             is_obj_appearing = object_score_logits > 0

-            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
-            # consistent with the actual mask prediction
-            low_res_multimasks = torch.where(
-                is_obj_appearing[:, None, None],
-                low_res_multimasks,
-                NO_OBJ_SCORE,
-            )
+            # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
+            low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)

         # convert masks from possibly bfloat16 (or float16) to float32
         # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
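A standalone sketch of the consolidated torch.where above: the object score tensor is (B, 1) here, so is_obj_appearing[:, None, None] becomes (B, 1, 1, 1) and broadcasts over the (B, M, H, W) mask logits, hard-masking whole frames where no object appears. Shapes and values are illustrative:

import torch

NO_OBJ_SCORE = -1024.0  # placeholder logit used for "no object" in the SAM 2 code

object_score_logits = torch.tensor([[2.5], [-1.0]])  # hypothetical scores: object, no object
low_res_multimasks = torch.randn(2, 3, 64, 64)

is_obj_appearing = object_score_logits > 0            # (B, 1) boolean
low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)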
@@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module):
-
     def _prepare_backbone_features(self, backbone_out):
         """Prepares and flattens visual features from the image backbone output for further processing."""
         backbone_out = backbone_out.copy()
         assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
         assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels

@@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module):
             mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
         if self.sigmoid_bias_for_mem_enc != 0.0:
             mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
-        maskmem_out = self.memory_encoder(
-            pix_feat,
-            mask_for_mem,
-            skip_mask_sigmoid=True,  # sigmoid already applied
-        )
+        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
         maskmem_features = maskmem_out["vision_features"]
         maskmem_pos_enc = maskmem_out["vision_pos_enc"]
         # add a no-object embedding to the spatial memory to indicate that the frame
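A standalone sketch of the scale/bias step above: an affine remap of the sigmoid-space mask before memory encoding, e.g. scale 20 and bias -10 would map [0, 1] to roughly [-10, 10]. The values here are hypothetical; in the model they come from config:

import torch

sigmoid_scale_for_mem_enc = 20.0   # hypothetical config value
sigmoid_bias_for_mem_enc = -10.0   # hypothetical config value

mask_for_mem = torch.sigmoid(torch.randn(1, 1, 64, 64))  # mask already in [0, 1]
if sigmoid_scale_for_mem_enc != 1.0:
    mask_for_mem = mask_for_mem * sigmoid_scale_for_mem_enc
if sigmoid_bias_for_mem_enc != 0.0:
    mask_for_mem = mask_for_mem + sigmoid_bias_for_mem_enc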
@@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module):
                 track_in_reverse,
                 prev_sam_mask_logits,
             )
-
-        (
-            _,
-            _,
-            _,
-            low_res_masks,
-            high_res_masks,
-            obj_ptr,
-            object_score_logits,
-        ) = sam_outputs
+        _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

         current_out["pred_masks"] = low_res_masks
         current_out["pred_masks_high_res"] = high_res_masks
@@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module):
         # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
         current_out["object_score_logits"] = object_score_logits

-        # Finally run the memory encoder on the predicted mask to encode
-        # it into a new memory feature (that can be used in future frames)
+        # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
         self._encode_memory_in_output(
             current_vision_feats,
             feat_sizes,
@@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module):
             and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
         )

-    def _apply_non_overlapping_constraints(self, pred_masks):
-        """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
+    @staticmethod
+    def _apply_non_overlapping_constraints(pred_masks):
+        """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
         batch_size = pred_masks.size(0)
         if batch_size == 1:
             return pred_masks
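The diff elides the middle of _apply_non_overlapping_constraints; the next hunk shows its final lines. A self-contained sketch of the whole constraint, with the elided middle reconstructed along the lines of the upstream SAM 2 implementation (an assumption):

import torch

def apply_non_overlapping_constraints(pred_masks: torch.Tensor) -> torch.Tensor:
    """Keep only the highest-scoring object's logits at each pixel; suppress the rest."""
    batch_size = pred_masks.size(0)  # one (1, H, W) logit map per tracked object
    if batch_size == 1:
        return pred_masks
    # Winning object index at every spatial location
    max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)            # (1, 1, H, W)
    batch_obj_inds = torch.arange(batch_size, device=pred_masks.device)[:, None, None, None]
    keep = max_obj_inds == batch_obj_inds                                   # (N, 1, H, W)
    # Losing objects are clamped to at most -10 so they binarize to background
    return torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

print(apply_non_overlapping_constraints(torch.randn(3, 1, 64, 64)).shape)  # torch.Size([3, 1, 64, 64])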
@@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module):
         pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
         return pred_masks

+    def set_binarize(self, binarize=False):
+        """Set binarize for VideoPredictor."""
+        self.binarize_mask_from_pts_for_mem_enc = binarize
+
     def set_imgsz(self, imgsz):
         """
         Set image size to make model compatible with different image sizes.
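The new set_binarize hook above is what lets the video predictor toggle binarization of point-prompted masks before they are encoded into memory, alongside the existing set_imgsz resize hook. A hedged usage sketch; the sam.model attribute path and the imgsz tuple format are assumptions:

from ultralytics import SAM

sam = SAM("sam2_b.pt")
sam.model.set_binarize(True)        # binarize point-prompted masks before memory encoding
sam.model.set_imgsz((1024, 1024))   # make the model compatible with a different image size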