ultralytics 8.3.38 SAM 2 video inference (#14851)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent 407815cf9e
commit dcc9bd536f
16 changed files with 917 additions and 124 deletions
@@ -194,6 +194,34 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide

        yolo predict model=sam2.1_b.pt source=path/to/video.mp4
        ```

+#### Segment Video and Track objects
+
+!!! example "Segment Video"
+
+    Segment the entire video content with specific prompts and track objects.
+
+    === "Python"
+
+        ```python
+        from ultralytics.models.sam import SAM2VideoPredictor
+
+        # Create SAM2VideoPredictor
+        overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
+        predictor = SAM2VideoPredictor(overrides=overrides)
+
+        # Run inference with a single point
+        results = predictor(source="test.mp4", points=[920, 470], labels=1)
+
+        # Run inference with multiple points
+        results = predictor(source="test.mp4", points=[[920, 470], [909, 138]], labels=[1, 1])
+
+        # Run inference with a multi-point prompt per object
+        results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 1]])
+
+        # Run inference with a negative point prompt (label 0)
+        results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 0]])
+        ```
+
- This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts (bboxes/points/masks) are provided.
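The `results` returned by `SAM2VideoPredictor` are standard Ultralytics `Results` objects, so a quick follow-up sketch for saving annotated frames could look like the block below. This is a hedged illustration, not part of the commit: the video path and output filenames are hypothetical, and it assumes OpenCV is installed and that `Results.plot()` is available, as with other Ultralytics predictors.

```python
import cv2

from ultralytics.models.sam import SAM2VideoPredictor

# Hypothetical sketch: track one object through a video and write annotated frames to disk.
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)
results = predictor(source="test.mp4", points=[920, 470], labels=1)

for i, r in enumerate(results):
    annotated = r.plot()  # BGR numpy array with the tracked mask drawn
    cv2.imwrite(f"frame_{i:04d}.jpg", annotated)
```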
## SAM 2 comparison vs YOLOv8
@@ -17,4 +17,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode

## ::: ultralytics.models.sam.predict.SAM2Predictor

+<br><br><hr><br>
+
+## ::: ultralytics.models.sam.predict.SAM2VideoPredictor
+
<br><br>
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.37"
+__version__ = "8.3.38"

import os
@@ -740,9 +740,8 @@ def parse_key_value_pair(pair: str = "key=value"):
        pair (str): A string containing a key-value pair in the format "key=value".

    Returns:
-        (tuple): A tuple containing two elements:
-            - key (str): The parsed key.
-            - value (str): The parsed value.
+        key (str): The parsed key.
+        value (str): The parsed value.

    Raises:
        AssertionError: If the value is missing or empty.
@@ -2111,10 +2111,9 @@ class Format:
            h (int): Height of the image.

        Returns:
-            (tuple): Tuple containing:
-                masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
-                instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
-                cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
+            masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
+            instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
+            cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.

        Notes:
            - If self.mask_overlap is True, masks are overlapped and sorted by area.
@@ -354,7 +354,7 @@ class LoadImagesAndVideos:
        self.nf = ni + nv  # number of files
        self.ni = ni  # number of images
        self.video_flag = [False] * ni + [True] * nv
-        self.mode = "image"
+        self.mode = "video" if ni == 0 else "image"  # default to video if no images
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = batch
        if any(videos):
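A rough illustration of what this new default changes: a pure-video source now reports `mode == "video"` from the first frame, which is what the new `SAM2VideoPredictor.init_state` later in this commit asserts. The snippet is a sketch only; the video path is hypothetical and the `(path, batch, vid_stride)` constructor shape is an assumption.

```python
from ultralytics.data.loaders import LoadImagesAndVideos

# Hypothetical path; with ni == 0 (no image files), mode now starts as "video" instead of "image".
loader = LoadImagesAndVideos("test.mp4", batch=1, vid_stride=1)
print(loader.mode)  # expected: "video"
```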
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from .model import SAM
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor

-__all__ = "SAM", "Predictor", "SAM2Predictor"  # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor"  # tuple or list
@@ -148,7 +148,7 @@ class SAM(Model):
            verbose (bool): If True, prints the information to the console.

        Returns:
-            (Tuple): A tuple containing the model's information (string representations of the model).
+            (tuple): A tuple containing the model's information (string representations of the model).

        Examples:
            >>> sam = SAM("sam_b.pt")
@@ -36,8 +36,6 @@ class SAMModel(nn.Module):
        image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
        prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
        mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
-        pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
-        pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).

    Methods:
        __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
@@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module):
        self.sam_prompt_embed_dim = self.hidden_dim
        self.sam_image_embedding_size = self.image_size // self.backbone_stride

-        # build PromptEncoder and MaskDecoder from SAM
-        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+        # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.sam_prompt_embed_dim,
            image_embedding_size=(
@@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module):
            low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
            high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
            ious: Tensor of shape (B, M) with estimated IoU for each output mask.
-            low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
-            high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
+            low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+            high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
            obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
            object_score_logits: Tensor of shape (B,) with object score logits.
@@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module):
            boxes=None,
            masks=sam_mask_prompt,
        )
-        (
-            low_res_multimasks,
-            ious,
-            sam_output_tokens,
-            object_score_logits,
-        ) = self.sam_mask_decoder(
+        low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
@@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module):
        if self.pred_obj_scores:
            is_obj_appearing = object_score_logits > 0

-            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
-            # consistent with the actual mask prediction
-            low_res_multimasks = torch.where(
-                is_obj_appearing[:, None, None],
-                low_res_multimasks,
-                NO_OBJ_SCORE,
-            )
+            # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
+            low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)

        # convert masks from possibly bfloat16 (or float16) to float32
        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
@@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module):

    def _prepare_backbone_features(self, backbone_out):
        """Prepares and flattens visual features from the image backbone output for further processing."""
-        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
@@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module):
            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
        if self.sigmoid_bias_for_mem_enc != 0.0:
            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
-        maskmem_out = self.memory_encoder(
-            pix_feat,
-            mask_for_mem,
-            skip_mask_sigmoid=True,  # sigmoid already applied
-        )
+        maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True)  # sigmoid already applied
        maskmem_features = maskmem_out["vision_features"]
        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
        # add a no-object embedding to the spatial memory to indicate that the frame
@@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module):
            track_in_reverse,
            prev_sam_mask_logits,
        )
-        (
-            _,
-            _,
-            _,
-            low_res_masks,
-            high_res_masks,
-            obj_ptr,
-            object_score_logits,
-        ) = sam_outputs
+        _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs

        current_out["pred_masks"] = low_res_masks
        current_out["pred_masks_high_res"] = high_res_masks
@@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module):
        # it's mainly used in the demo to encode spatial memories w/ consolidated masks)
        current_out["object_score_logits"] = object_score_logits

-        # Finally run the memory encoder on the predicted mask to encode
-        # it into a new memory feature (that can be used in future frames)
+        # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
        self._encode_memory_in_output(
            current_vision_feats,
            feat_sizes,
@@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module):
            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
        )

-    def _apply_non_overlapping_constraints(self, pred_masks):
-        """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
+    @staticmethod
+    def _apply_non_overlapping_constraints(pred_masks):
+        """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
        batch_size = pred_masks.size(0)
        if batch_size == 1:
            return pred_masks

@@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module):
        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
        return pred_masks

+    def set_binarize(self, binarize=False):
+        """Set binarize for VideoPredictor."""
+        self.binarize_mask_from_pts_for_mem_enc = binarize
+
    def set_imgsz(self, imgsz):
        """
        Set image size to make model compatible with different image sizes.
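To make the staticmethod change above easier to follow, here is a minimal, standalone sketch of the non-overlapping constraint in plain PyTorch. The `-10.0` clamp mirrors the line shown in this hunk; the winner-selection step (argmax over objects) is reconstructed from the surrounding behavior and should be read as an assumption, not the verbatim method body.

```python
import torch


def apply_non_overlapping_constraints(pred_masks: torch.Tensor) -> torch.Tensor:
    """Keep only the highest-scoring object's logits at each pixel; suppress the rest."""
    num_objects = pred_masks.size(0)  # pred_masks assumed shape (num_objects, 1, H, W)
    if num_objects == 1:
        return pred_masks
    # Index of the winning object at every spatial location
    max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
    obj_inds = torch.arange(num_objects, device=pred_masks.device)[:, None, None, None]
    keep = max_obj_inds == obj_inds
    # Losing objects get their logits clamped well below the mask threshold
    return torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))


# Example: three objects competing over a 4x4 grid
masks = torch.randn(3, 1, 4, 4)
print(apply_non_overlapping_constraints(masks).shape)  # torch.Size([3, 1, 4, 4])
```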
@@ -8,6 +8,8 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
segmentation tasks.
"""

+from collections import OrderedDict
+
import numpy as np
import torch
import torch.nn.functional as F

@@ -16,7 +18,7 @@ from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
-from ultralytics.utils.torch_utils import select_device
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode

from .amg import (
    batch_iterator,
@@ -95,7 +97,7 @@ class Predictor(BasePredictor):
        """
        if overrides is None:
            overrides = {}
-        overrides.update(dict(task="segment", mode="predict"))
+        overrides.update(dict(task="segment", mode="predict", batch=1))
        super().__init__(cfg, overrides, _callbacks)
        self.args.retina_masks = True
        self.im = None
@@ -114,7 +116,7 @@ class Predictor(BasePredictor):
            im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.

        Returns:
-            (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+            im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.

        Examples:
            >>> predictor = Predictor()
@@ -181,10 +183,9 @@ class Predictor(BasePredictor):
            **kwargs (Any): Additional keyword arguments.

        Returns:
-            (tuple): Contains the following three elements:
-                - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
-                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
-                - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+            (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+            (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.

        Examples:
            >>> predictor = Predictor()
@@ -222,10 +223,8 @@ class Predictor(BasePredictor):
            AssertionError: If the number of points don't match the number of labels, in case labels were passed.

        Returns:
-            (tuple): Tuple containing:
-                - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
-                - np.ndarray: Quality scores predicted by the model for each mask, with length C.
-                - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
+            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): Quality scores predicted by the model for each mask, with length C.

        Examples:
            >>> predictor = Predictor()
@@ -329,10 +328,9 @@ class Predictor(BasePredictor):
            crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.

        Returns:
-            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
-                - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
-                - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
-                - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+            pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+            pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+            pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).

        Examples:
            >>> predictor = Predictor()
@@ -408,7 +406,7 @@ class Predictor(BasePredictor):

        return pred_masks, pred_scores, pred_bboxes

-    def setup_model(self, model, verbose=True):
+    def setup_model(self, model=None, verbose=True):
        """
        Initializes the Segment Anything Model (SAM) for inference.
@@ -416,7 +414,7 @@ class Predictor(BasePredictor):
        parameters for image normalization and other Ultralytics compatibility settings.

        Args:
-            model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
+            model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
            verbose (bool): If True, prints selected device information.

        Examples:
@@ -459,7 +457,7 @@ class Predictor(BasePredictor):
            orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.

        Returns:
-            (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+            results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
                metadata for each processed image.

        Examples:
@@ -586,9 +584,8 @@ class Predictor(BasePredictor):
            nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.

        Returns:
-            (tuple):
-                - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
-                - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+            new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+            keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.

        Examples:
            >>> masks = torch.rand(5, 640, 640) > 0.5  # 5 random binary masks
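Continuing the docstring's own example, a hedged usage sketch: it assumes `remove_small_regions` is exposed as a static method with `(masks, min_area, nms_thresh)` parameters, and the `min_area` value here is an arbitrary illustration.

```python
import torch

from ultralytics.models.sam.predict import Predictor

masks = torch.rand(5, 640, 640) > 0.5  # 5 random binary masks, as in the docstring example
new_masks, keep = Predictor.remove_small_regions(masks, min_area=100, nms_thresh=0.7)
print(new_masks.shape, keep)  # cleaned masks, plus indices of masks surviving NMS
```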
@@ -690,10 +687,8 @@ class SAM2Predictor(Predictor):
            img_idx (int): Index of the image in the batch to process.

        Returns:
-            (tuple): Tuple containing:
-                - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
-                - np.ndarray: Quality scores for each mask, with length C.
-                - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
+            (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+            (np.ndarray): Quality scores for each mask, with length C.

        Examples:
            >>> predictor = SAM2Predictor(cfg)
|
||||||
"""
|
"""
|
||||||
features = self.get_im_features(im) if self.features is None else self.features
|
features = self.get_im_features(im) if self.features is None else self.features
|
||||||
|
|
||||||
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||||
points = (points, labels) if points is not None else None
|
points = (points, labels) if points is not None else None
|
||||||
|
|
||||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||||
|
|
@ -751,7 +746,7 @@ class SAM2Predictor(Predictor):
|
||||||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
|
(tuple): A tuple containing transformed points, labels, and masks.
|
||||||
"""
|
"""
|
||||||
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
|
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
|
||||||
if bboxes is not None:
|
if bboxes is not None:
|
||||||
|
|
@ -764,7 +759,7 @@ class SAM2Predictor(Predictor):
|
||||||
labels = torch.cat([bbox_labels, labels], dim=1)
|
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||||
else:
|
else:
|
||||||
points, labels = bboxes, bbox_labels
|
points, labels = bboxes, bbox_labels
|
||||||
return bboxes, points, labels, masks
|
return points, labels, masks
|
||||||
|
|
||||||
def set_image(self, image):
|
def set_image(self, image):
|
||||||
"""
|
"""
|
||||||
|
|
@@ -815,3 +810,797 @@ class SAM2Predictor(Predictor):
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+
+
+class SAM2VideoPredictor(SAM2Predictor):
+    """
+    SAM2VideoPredictor to handle user interactions with videos and manage inference states.
+
+    This class extends the functionality of SAM2Predictor to support video processing and maintains
+    the state of inference operations. It includes configurations for managing non-overlapping masks,
+    clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
+
+    Attributes:
+        inference_state (Dict): A dictionary to store the current state of inference operations.
+        non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
+        clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
+        clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
+        callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
+
+    Args:
+        cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
+        overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
+        _callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
+
+    Note:
+        The `fill_hole_area` attribute is defined but not used in the current implementation.
+    """
+
+    # fill_hole_area = 8  # not used
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """
+        Initialize the predictor with configuration and optional overrides.
+
+        This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
+        specified overrides, and sets up the inference state along with certain flags
+        that control the behavior of the predictor.
+
+        Args:
+            cfg (Dict): Configuration dictionary containing default settings.
+            overrides (Dict | None): Dictionary of values to override default configuration.
+            _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+            >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
+            >>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
+        """
+        super().__init__(cfg, overrides, _callbacks)
+        self.inference_state = {}
+        self.non_overlap_masks = True
+        self.clear_non_cond_mem_around_input = False
+        self.clear_non_cond_mem_for_multi_obj = False
+        self.callbacks["on_predict_start"].append(self.init_state)
+
+    def get_model(self):
+        """
+        Retrieves and configures the model with binarization enabled.
+
+        Note:
+            This method overrides the base class implementation to set the binarize flag to True.
+        """
+        model = super().get_model()
+        model.set_binarize(True)
+        return model
+    def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
+        """
+        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+        mask decoder for real-time and promptable segmentation tasks.
+
+        Args:
+            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
+
+        Returns:
+            (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+            (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+        """
+        # Override prompts if any stored in self.prompts
+        bboxes = self.prompts.pop("bboxes", bboxes)
+        points = self.prompts.pop("points", points)
+        masks = self.prompts.pop("masks", masks)
+
+        frame = self.dataset.frame
+        self.inference_state["im"] = im
+        output_dict = self.inference_state["output_dict"]
+        if len(output_dict["cond_frame_outputs"]) == 0:  # initialize prompts
+            points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+            if points is not None:
+                for i in range(len(points)):
+                    self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
+            elif masks is not None:
+                for i in range(len(masks)):
+                    self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
+            self.propagate_in_video_preflight()
+
+        consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+        batch_size = len(self.inference_state["obj_idx_to_id"])
+        if len(output_dict["cond_frame_outputs"]) == 0:
+            raise RuntimeError("No points are provided; please add points first")
+
+        if frame in consolidated_frame_inds["cond_frame_outputs"]:
+            storage_key = "cond_frame_outputs"
+            current_out = output_dict[storage_key][frame]
+            if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+                # clear non-conditioning memory of the surrounding frames
+                self._clear_non_cond_mem_around_input(frame)
+        elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
+            storage_key = "non_cond_frame_outputs"
+            current_out = output_dict[storage_key][frame]
+        else:
+            storage_key = "non_cond_frame_outputs"
+            current_out = self._run_single_frame_inference(
+                output_dict=output_dict,
+                frame_idx=frame,
+                batch_size=batch_size,
+                is_init_cond_frame=False,
+                point_inputs=None,
+                mask_inputs=None,
+                reverse=False,
+                run_mem_encoder=True,
+            )
+            output_dict[storage_key][frame] = current_out
+        # Create slices of per-object outputs for subsequent interaction with each
+        # individual object after tracking.
+        self._add_output_per_object(frame, current_out, storage_key)
+        self.inference_state["frames_already_tracked"].append(frame)
+        pred_masks = current_out["pred_masks"].flatten(0, 1)
+        pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0]  # filter blank masks
+
+        return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
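Since each prompt above is registered as its own object id (the `add_new_prompts(obj_id=i, ...)` loop), a hedged end-to-end sketch of tracking two separate objects follows. The source path and coordinates are hypothetical; the nested prompt layout follows the per-object format shown in the docs example earlier in this commit.

```python
from ultralytics.models.sam import SAM2VideoPredictor

# Two objects, one foreground point each; the outer list groups points per object.
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)
results = predictor(source="test.mp4", points=[[[920, 470]], [[909, 138]]], labels=[[1], [1]])

for r in results:
    if r.masks is not None:
        print(r.masks.data.shape)  # up to (2, H, W): one mask per tracked object per frame
```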
+    def postprocess(self, preds, img, orig_imgs):
+        """
+        Post-processes the predictions to apply non-overlapping constraints if required.
+
+        This method extends the post-processing functionality by applying non-overlapping constraints
+        to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
+        the masks do not overlap, which can be useful for certain applications.
+
+        Args:
+            preds (Tuple[torch.Tensor]): The predictions from the model.
+            img (torch.Tensor): The processed image tensor.
+            orig_imgs (List[np.ndarray]): The original images before processing.
+
+        Returns:
+            results (list): The post-processed predictions.
+
+        Note:
+            If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
+        """
+        results = super().postprocess(preds, img, orig_imgs)
+        if self.non_overlap_masks:
+            for result in results:
+                if result.masks is None or len(result.masks) == 0:
+                    continue
+                result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
+        return results
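Because `__init__` wires `init_state` into the `on_predict_start` event, user code can hook the same lifecycle. A hedged sketch (assumes the usual Ultralytics `add_callback` hook and that `init_state` has populated the video session before later callbacks run; the source path is hypothetical):

```python
from ultralytics.models.sam import SAM2VideoPredictor

predictor = SAM2VideoPredictor(overrides=dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt"))


def log_session(p):
    # Hypothetical callback: runs after init_state, so the inference_state dict is available.
    print(f"Tracking session started with {p.inference_state.get('num_frames', '?')} frames")


predictor.add_callback("on_predict_start", log_session)
results = predictor(source="test.mp4", points=[920, 470], labels=1)
```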
@smart_inference_mode()
|
||||||
|
def add_new_prompts(
|
||||||
|
self,
|
||||||
|
obj_id,
|
||||||
|
points=None,
|
||||||
|
labels=None,
|
||||||
|
masks=None,
|
||||||
|
frame_idx=0,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Adds new points or masks to a specific frame for a given object ID.
|
||||||
|
|
||||||
|
This method updates the inference state with new prompts (points or masks) for a specified
|
||||||
|
object and frame index. It ensures that the prompts are either points or masks, but not both,
|
||||||
|
and updates the internal state accordingly. It also handles the generation of new segmentations
|
||||||
|
based on the provided prompts and the existing state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj_id (int): The ID of the object to which the prompts are associated.
|
||||||
|
points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
|
||||||
|
labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
|
||||||
|
masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
|
||||||
|
frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If both `masks` and `points` are provided, or neither is provided.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Only one type of prompt (either points or masks) can be added per call.
|
||||||
|
- If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
|
||||||
|
- The method handles the consolidation of outputs and resizing of masks to the original video resolution.
|
||||||
|
"""
|
||||||
|
assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
|
||||||
|
obj_idx = self._obj_id_to_idx(obj_id)
|
||||||
|
|
||||||
|
point_inputs = None
|
||||||
|
pop_key = "point_inputs_per_obj"
|
||||||
|
if points is not None:
|
||||||
|
point_inputs = {"point_coords": points, "point_labels": labels}
|
||||||
|
self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
|
||||||
|
pop_key = "mask_inputs_per_obj"
|
||||||
|
self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
|
||||||
|
self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
|
||||||
|
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
||||||
|
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||||
|
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||||
|
# the input points will be used to correct the already tracked masks.
|
||||||
|
is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
|
||||||
|
obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
|
||||||
|
obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||||
|
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||||
|
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
||||||
|
is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
|
||||||
|
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||||
|
|
||||||
|
# Get any previously predicted mask logits on this object and feed it along with
|
||||||
|
# the new clicks into the SAM mask decoder.
|
||||||
|
prev_sam_mask_logits = None
|
||||||
|
# lookup temporary output dict first, which contains the most recent output
|
||||||
|
# (if not found, then lookup conditioning and non-conditioning frame output)
|
||||||
|
if point_inputs is not None:
|
||||||
|
prev_out = (
|
||||||
|
obj_temp_output_dict[storage_key].get(frame_idx)
|
||||||
|
or obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
||||||
|
or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||||
|
)
|
||||||
|
|
||||||
|
if prev_out is not None and prev_out.get("pred_masks") is not None:
|
||||||
|
prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
|
||||||
|
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||||
|
prev_sam_mask_logits.clamp_(-32.0, 32.0)
|
||||||
|
current_out = self._run_single_frame_inference(
|
||||||
|
output_dict=obj_output_dict, # run on the slice of a single object
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
batch_size=1, # run on the slice of a single object
|
||||||
|
is_init_cond_frame=is_init_cond_frame,
|
||||||
|
point_inputs=point_inputs,
|
||||||
|
mask_inputs=masks,
|
||||||
|
reverse=False,
|
||||||
|
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
||||||
|
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
||||||
|
# allows us to enforce non-overlapping constraints on all objects before encoding
|
||||||
|
# them into memory.
|
||||||
|
run_mem_encoder=False,
|
||||||
|
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||||
|
)
|
||||||
|
# Add the output to the output dict (to be used as future memory)
|
||||||
|
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
||||||
|
|
||||||
|
# Resize the output mask to the original video resolution
|
||||||
|
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||||
|
frame_idx,
|
||||||
|
is_cond=is_cond,
|
||||||
|
run_mem_encoder=False,
|
||||||
|
)
|
||||||
|
pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
|
||||||
|
return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
|
||||||
|
|
||||||
|
@smart_inference_mode()
|
||||||
|
def propagate_in_video_preflight(self):
|
||||||
|
"""
|
||||||
|
Prepare inference_state and consolidate temporary outputs before tracking.
|
||||||
|
|
||||||
|
This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
|
||||||
|
It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
|
||||||
|
Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
|
||||||
|
with the provided inputs.
|
||||||
|
"""
|
||||||
|
# Tracking has started and we don't allow adding new objects until session is reset.
|
||||||
|
self.inference_state["tracking_has_started"] = True
|
||||||
|
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||||
|
|
||||||
|
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||||
|
# add them into "output_dict".
|
||||||
|
temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
|
||||||
|
output_dict = self.inference_state["output_dict"]
|
||||||
|
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
||||||
|
# temporary outputs have been added (either in this call or any previous calls
|
||||||
|
# to `propagate_in_video_preflight`).
|
||||||
|
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||||
|
for is_cond in {False, True}:
|
||||||
|
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||||
|
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||||
|
# Find all the frames that contain temporary outputs for any objects
|
||||||
|
# (these should be the frames that have just received clicks for mask inputs
|
||||||
|
# via `add_new_points` or `add_new_mask`)
|
||||||
|
temp_frame_inds = set()
|
||||||
|
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||||
|
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||||
|
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||||
|
# consolidate the temprary output across all objects on this frame
|
||||||
|
for frame_idx in temp_frame_inds:
|
||||||
|
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||||
|
frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||||
|
)
|
||||||
|
# merge them into "output_dict" and also create per-object slices
|
||||||
|
output_dict[storage_key][frame_idx] = consolidated_out
|
||||||
|
self._add_output_per_object(frame_idx, consolidated_out, storage_key)
|
||||||
|
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
|
||||||
|
# clear non-conditioning memory of the surrounding frames
|
||||||
|
self._clear_non_cond_mem_around_input(frame_idx)
|
||||||
|
|
||||||
|
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||||
|
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||||
|
obj_temp_output_dict[storage_key].clear()
|
||||||
|
|
||||||
|
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||||
|
# output on the same frame in "non_cond_frame_outputs"
|
||||||
|
for frame_idx in output_dict["cond_frame_outputs"]:
|
||||||
|
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||||
|
for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
|
||||||
|
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||||
|
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||||
|
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||||
|
assert frame_idx in output_dict["cond_frame_outputs"]
|
||||||
|
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||||
|
|
||||||
|
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
||||||
|
# with either points or mask inputs (which should be true under a correct workflow).
|
||||||
|
all_consolidated_frame_inds = (
|
||||||
|
consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
|
||||||
|
)
|
||||||
|
input_frames_inds = set()
|
||||||
|
for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
|
||||||
|
input_frames_inds.update(point_inputs_per_frame.keys())
|
||||||
|
for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
|
||||||
|
input_frames_inds.update(mask_inputs_per_frame.keys())
|
||||||
|
assert all_consolidated_frame_inds == input_frames_inds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_state(predictor):
|
||||||
|
"""
|
||||||
|
Initialize an inference state for the predictor.
|
||||||
|
|
||||||
|
This function sets up the initial state required for performing inference on video data.
|
||||||
|
It includes initializing various dictionaries and ordered dictionaries that will store
|
||||||
|
inputs, outputs, and other metadata relevant to the tracking process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
|
||||||
|
"""
|
||||||
|
if len(predictor.inference_state) > 0: # means initialized
|
||||||
|
return
|
||||||
|
assert predictor.dataset is not None
|
||||||
|
assert predictor.dataset.mode == "video"
|
||||||
|
|
||||||
|
inference_state = {}
|
||||||
|
inference_state["num_frames"] = predictor.dataset.frames
|
||||||
|
# inputs on each frame
|
||||||
|
inference_state["point_inputs_per_obj"] = {}
|
||||||
|
inference_state["mask_inputs_per_obj"] = {}
|
||||||
|
# values that don't change across frames (so we only need to hold one copy of them)
|
||||||
|
inference_state["constants"] = {}
|
||||||
|
# mapping between client-side object id and model-side object index
|
||||||
|
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||||
|
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||||
|
inference_state["obj_ids"] = []
|
||||||
|
# A storage to hold the model's tracking results and states on each frame
|
||||||
|
inference_state["output_dict"] = {
|
||||||
|
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
}
|
||||||
|
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||||
|
inference_state["output_dict_per_obj"] = {}
|
||||||
|
# A temporary storage to hold new outputs when user interact with a frame
|
||||||
|
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
||||||
|
inference_state["temp_output_dict_per_obj"] = {}
|
||||||
|
# Frames that already holds consolidated outputs from click or mask inputs
|
||||||
|
# (we directly use their consolidated outputs during tracking)
|
||||||
|
inference_state["consolidated_frame_inds"] = {
|
||||||
|
"cond_frame_outputs": set(), # set containing frame indices
|
||||||
|
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||||
|
}
|
||||||
|
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||||
|
inference_state["tracking_has_started"] = False
|
||||||
|
inference_state["frames_already_tracked"] = []
|
||||||
|
predictor.inference_state = inference_state
|
||||||
|
|
||||||
|
def get_im_features(self, im, batch=1):
|
||||||
|
"""
|
||||||
|
Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
im (torch.Tensor): The input image tensor.
|
||||||
|
batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
vis_feats (torch.Tensor): The visual features extracted from the image.
|
||||||
|
vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
|
||||||
|
feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- If `batch` is greater than 1, the features are expanded to fit the batch size.
|
||||||
|
- The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
|
||||||
|
"""
|
||||||
|
backbone_out = self.model.forward_image(im)
|
||||||
|
if batch > 1: # expand features if there's more than one prompt
|
||||||
|
for i, feat in enumerate(backbone_out["backbone_fpn"]):
|
||||||
|
backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
|
||||||
|
for i, pos in enumerate(backbone_out["vision_pos_enc"]):
|
||||||
|
pos = pos.expand(batch, -1, -1, -1)
|
||||||
|
backbone_out["vision_pos_enc"][i] = pos
|
||||||
|
_, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
|
||||||
|
return vis_feats, vis_pos_embed, feat_sizes
|
||||||
|
|
||||||
|
def _obj_id_to_idx(self, obj_id):
|
||||||
|
"""
|
||||||
|
Map client-side object id to model-side object index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj_id (int): The unique identifier of the object provided by the client side.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
obj_idx (int): The index of the object on the model side.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If an attempt is made to add a new object after tracking has started.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The method updates or retrieves mappings between object IDs and indices stored in
|
||||||
|
`inference_state`.
|
||||||
|
- It ensures that new objects can only be added before tracking commences.
|
||||||
|
- It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
|
||||||
|
- Additional data structures are initialized for the new object to store inputs and outputs.
|
||||||
|
"""
|
||||||
|
obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||||
|
if obj_idx is not None:
|
||||||
|
return obj_idx
|
||||||
|
|
||||||
|
# This is a new object id not sent to the server before. We only allow adding
|
||||||
|
# new objects *before* the tracking starts.
|
||||||
|
allow_new_object = not self.inference_state["tracking_has_started"]
|
||||||
|
if allow_new_object:
|
||||||
|
# get the next object slot
|
||||||
|
obj_idx = len(self.inference_state["obj_id_to_idx"])
|
||||||
|
self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
||||||
|
self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
||||||
|
self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
|
||||||
|
# set up input and output structures for this object
|
||||||
|
self.inference_state["point_inputs_per_obj"][obj_idx] = {}
|
||||||
|
self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
||||||
|
self.inference_state["output_dict_per_obj"][obj_idx] = {
|
||||||
|
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
}
|
||||||
|
self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
||||||
|
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||||
|
}
|
||||||
|
return obj_idx
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot add new object id {obj_id} after tracking starts. "
|
||||||
|
f"All existing object ids: {self.inference_state['obj_ids']}. "
|
||||||
|
f"Please call 'reset_state' to restart from scratch."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _run_single_frame_inference(
|
||||||
|
self,
|
||||||
|
output_dict,
|
||||||
|
frame_idx,
|
||||||
|
batch_size,
|
||||||
|
is_init_cond_frame,
|
||||||
|
point_inputs,
|
||||||
|
mask_inputs,
|
||||||
|
reverse,
|
||||||
|
run_mem_encoder,
|
||||||
|
prev_sam_mask_logits=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run tracking on a single frame based on current inputs and previous memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dict (Dict): The dictionary containing the output states of the tracking process.
|
||||||
|
frame_idx (int): The index of the current frame.
|
||||||
|
batch_size (int): The batch size for processing the frame.
|
||||||
|
is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
|
||||||
|
point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
|
||||||
|
mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
|
||||||
|
reverse (bool): Indicates if the tracking should be performed in reverse order.
|
||||||
|
run_mem_encoder (bool): Indicates if the memory encoder should be executed.
|
||||||
|
prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
|
||||||
|
- The method retrieves image features using the `get_im_features` method.
|
||||||
|
- The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
|
||||||
|
- The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
|
||||||
|
"""
|
||||||
|
# Retrieve correct image features
|
||||||
|
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
|
||||||
|
self.inference_state["im"], batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# point and mask should not appear as input simultaneously on the same frame
|
||||||
|
assert point_inputs is None or mask_inputs is None
|
||||||
|
current_out = self.model.track_step(
|
||||||
|
frame_idx=frame_idx,
|
||||||
|
is_init_cond_frame=is_init_cond_frame,
|
||||||
|
current_vision_feats=current_vision_feats,
|
||||||
|
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||||
|
feat_sizes=feat_sizes,
|
||||||
|
point_inputs=point_inputs,
|
||||||
|
mask_inputs=mask_inputs,
|
||||||
|
output_dict=output_dict,
|
||||||
|
num_frames=self.inference_state["num_frames"],
|
||||||
|
track_in_reverse=reverse,
|
||||||
|
run_mem_encoder=run_mem_encoder,
|
||||||
|
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
maskmem_features = current_out["maskmem_features"]
|
||||||
|
if maskmem_features is not None:
|
||||||
|
current_out["maskmem_features"] = maskmem_features.to(
|
||||||
|
dtype=torch.float16, device=self.device, non_blocking=True
|
||||||
|
)
|
||||||
|
# NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
|
||||||
|
# potentially fill holes in the predicted masks
|
||||||
|
# if self.fill_hole_area > 0:
|
||||||
|
# pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
|
||||||
|
# pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
|
||||||
|
|
||||||
|
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||||
|
current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
|
||||||
|
return current_out
|
||||||
|
|
||||||
|
    def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
        """
        Caches and manages the positional encoding for mask memory across frames and objects.

        This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for mask memory, which is
        constant across frames and objects, reducing the amount of redundant information stored during an inference
        session. It checks whether the positional encoding has already been cached; if not, it caches a slice of the
        provided encoding. If the batch size is greater than one, the cached positional encoding is expanded to match
        the current batch size.

        Args:
            out_maskmem_pos_enc (List[torch.Tensor] | None): The positional encoding for mask memory, provided as a
                list of tensors or None.

        Returns:
            out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or
                expanded.

        Note:
            - Only a single object's slice is cached since the encoding is the same across objects.
            - The cached encoding is stored in the session's constants and reused on subsequent calls.
            - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
        """
        model_constants = self.inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            if batch_size > 1:
                out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
        return out_maskmem_pos_enc
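Caching a single object's slice and expanding it on demand is cheap because `Tensor.expand` returns a view rather than a copy. A minimal sketch, with invented shapes:

```python
import torch

# Cache the encoding for a single object (shape is illustrative).
full_pos_enc = torch.randn(4, 256, 32, 32)  # pretend batch of 4 objects
cached = full_pos_enc[0:1].clone()          # store only one object's slice

# Later, expand the cached slice back to the current batch size without copying.
expanded = cached.expand(4, -1, -1, -1)
print(expanded.shape)                            # torch.Size([4, 256, 32, 32])
print(expanded.data_ptr() == cached.data_ptr())  # True: it is a view, no extra memory
```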
    def _consolidate_temp_output_across_obj(
        self,
        frame_idx,
        is_cond=False,
        run_mem_encoder=False,
    ):
        """
        Consolidates per-object temporary outputs into a single output for all objects.

        This method combines the temporary outputs for each object on a given frame into a unified output. It fills in
        any missing objects either from the main output dictionary or leaves placeholders if they do not exist there.
        Optionally, it can re-run the memory encoder after applying non-overlapping constraints to the object scores.

        Args:
            frame_idx (int): The index of the frame for which to consolidate outputs.
            is_cond (bool, optional): Indicates if the frame is considered a conditioning frame. Defaults to False.
            run_mem_encoder (bool, optional): Specifies whether to run the memory encoder after consolidating the
                outputs. Defaults to False.

        Returns:
            consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.

        Note:
            - The consolidated output is initialized with placeholder values for missing objects.
            - Outputs are looked up in both the temporary and main output dictionaries.
            - `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True, in which
              case non-overlapping constraints are applied and the memory encoder is re-run.
        """
        batch_size = len(self.inference_state["obj_idx_to_id"])
        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"

        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
        # will be added when rerunning the memory encoder after applying non-overlapping
        # constraints to object scores. Its "pred_masks" are prefilled with a large
        # negative value (NO_OBJ_SCORE) to represent missing objects.
        consolidated_out = {
            "maskmem_features": None,
            "maskmem_pos_enc": None,
            "pred_masks": torch.full(
                size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
                fill_value=-1024.0,
                dtype=torch.float32,
                device=self.device,
            ),
            "obj_ptr": torch.full(
                size=(batch_size, self.model.hidden_dim),
                fill_value=-1024.0,
                dtype=torch.float32,
                device=self.device,
            ),
            "object_score_logits": torch.full(
                size=(batch_size, 1),
                # default to 10.0 for object_score_logits, i.e. assuming the object is
                # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
                fill_value=10.0,
                dtype=torch.float32,
                device=self.device,
            ),
        }
        for obj_idx in range(batch_size):
            obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
            obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
            out = (
                obj_temp_output_dict[storage_key].get(frame_idx)
                # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
                # we fall back and look up its previous output in "output_dict_per_obj".
                # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
                # "output_dict_per_obj" to find a previous output for this object.
                or obj_output_dict["cond_frame_outputs"].get(frame_idx)
                or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
            )
            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
            # placeholder above) and set its object pointer to be a dummy pointer.
            if out is None:
                # Fill in dummy object pointers for those objects without any inputs or
                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
                # i.e. when we need to build the memory for tracking).
                if run_mem_encoder:
                    # fill object pointer with a dummy pointer (based on an empty mask)
                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
                continue
            # Add the temporary object output mask to consolidated output mask
            consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]

        # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
        if run_mem_encoder:
            high_res_masks = F.interpolate(
                consolidated_out["pred_masks"],
                size=self.imgsz,
                mode="bilinear",
                align_corners=False,
            )
            if self.model.non_overlap_masks_for_mem_enc:
                high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
            consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
                batch_size=batch_size,
                high_res_masks=high_res_masks,
                is_mask_from_pts=True,  # these frames are what the user interacted with
                object_score_logits=consolidated_out["object_score_logits"],
            )

        return consolidated_out
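The placeholder logic above relies on prefilled large negative logits standing in for missing objects. A minimal sketch, with invented shapes and the same -1024.0 fill value:

```python
import torch

NO_OBJ_SCORE = -1024.0
batch_size, h, w = 3, 64, 64  # 3 tracked objects, toy mask resolution

# Prefill consolidated masks so objects with no output on this frame stay "missing".
pred_masks = torch.full((batch_size, 1, h, w), NO_OBJ_SCORE)

# Suppose only object 1 produced an output on this frame.
pred_masks[1:2] = torch.randn(1, 1, h, w)

present = pred_masks.sigmoid().flatten(1).max(dim=1).values > 0.5
print(present.tolist())  # e.g. [False, True, False]: untouched slices remain "no object"
```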
    def _get_empty_mask_ptr(self, frame_idx):
        """
        Get a dummy object pointer based on an empty mask on the current frame.

        Args:
            frame_idx (int): The index of the current frame for which to generate the dummy object pointer.

        Returns:
            (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
        """
        # Retrieve correct image features
        current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])

        # Feed the empty mask and image feature above to get a dummy object pointer
        current_out = self.model.track_step(
            frame_idx=frame_idx,
            is_init_cond_frame=True,
            current_vision_feats=current_vision_feats,
            current_vision_pos_embeds=current_vision_pos_embeds,
            feat_sizes=feat_sizes,
            point_inputs=None,
            # A dummy (empty) mask with a single object
            mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
            output_dict={},
            num_frames=self.inference_state["num_frames"],
            track_in_reverse=False,
            run_mem_encoder=False,
            prev_sam_mask_logits=None,
        )
        return current_out["obj_ptr"]

    def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
        """
        Run the memory encoder on masks.

        This is usually done after applying non-overlapping constraints to object scores. Since the scores have
        changed, the memory also needs to be recomputed with the memory encoder.

        Args:
            batch_size (int): The batch size for processing the frame.
            high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
            object_score_logits (torch.Tensor): Logits representing the object scores.
            is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.

        Returns:
            (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
        """
        # Retrieve correct image features
        current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
        maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            is_mask_from_pts=is_mask_from_pts,
            object_score_logits=object_score_logits,
        )

        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
        return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
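The memory encoder is typically re-run after non-overlapping constraints are applied to the consolidated masks (see `_consolidate_temp_output_across_obj` above). A minimal sketch of one way such a constraint can be imposed, keeping only the highest-scoring object per pixel; this `suppress_overlaps` helper is hypothetical and is not claimed to match `_apply_non_overlapping_constraints`:

```python
import torch


def suppress_overlaps(mask_logits: torch.Tensor, low: float = -10.0) -> torch.Tensor:
    """Keep, per pixel, only the highest-scoring object; clamp all other objects to a low logit."""
    winner = mask_logits.argmax(dim=0, keepdim=True)  # (1, 1, H, W): index of the best object per pixel
    is_winner = torch.arange(mask_logits.size(0)).view(-1, 1, 1, 1) == winner
    return torch.where(is_winner, mask_logits, mask_logits.clamp(max=low))


logits = torch.randn(3, 1, 8, 8)  # three overlapping object masks (toy values)
constrained = suppress_overlaps(logits)
print((constrained > 0).sum(dim=0).max().item() <= 1)  # True: at most one object "on" per pixel
```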
    def _add_output_per_object(self, frame_idx, current_out, storage_key):
        """
        Split a multi-object output into per-object output slices and add them to `output_dict_per_obj`.

        The resulting slices share the same tensor storage.

        Args:
            frame_idx (int): The index of the current frame.
            current_out (dict): The current output dictionary containing multi-object outputs.
            storage_key (str): The key used to store the output in the per-object output dictionary.
        """
        maskmem_features = current_out["maskmem_features"]
        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)

        maskmem_pos_enc = current_out["maskmem_pos_enc"]
        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)

        for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
            obj_slice = slice(obj_idx, obj_idx + 1)
            obj_out = {
                "maskmem_features": None,
                "maskmem_pos_enc": None,
                "pred_masks": current_out["pred_masks"][obj_slice],
                "obj_ptr": current_out["obj_ptr"][obj_slice],
            }
            if maskmem_features is not None:
                obj_out["maskmem_features"] = maskmem_features[obj_slice]
            if maskmem_pos_enc is not None:
                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
            obj_output_dict[storage_key][frame_idx] = obj_out
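The per-object slices written above share storage with the batched tensors, since basic slicing in PyTorch returns views. A minimal sketch, with invented shapes:

```python
import torch

# A batched multi-object output (2 objects, toy mask size).
pred_masks = torch.zeros(2, 1, 4, 4)

# Per-object "outputs" produced by slicing, as in _add_output_per_object.
per_obj = [pred_masks[i : i + 1] for i in range(2)]

# Writing through a slice is visible in the batched tensor: the slice is a view, not a copy.
per_obj[1].fill_(7.0)
print(pred_masks[1].max().item())                    # 7.0
print(per_obj[0].data_ptr() == pred_masks.data_ptr())  # True: shared storage
```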
    def _clear_non_cond_mem_around_input(self, frame_idx):
        """
        Remove the non-conditioning memory around the input frame.

        When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain
        outdated object appearance information and could confuse the model. This method clears those non-conditioning
        memories surrounding the interacted frame to avoid giving the model both old and new information about the
        object.

        Args:
            frame_idx (int): The index of the current frame where user interaction occurred.
        """
        r = self.model.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.model.num_maskmem
        frame_idx_end = frame_idx + r * self.model.num_maskmem
        for t in range(frame_idx_begin, frame_idx_end + 1):
            self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
            for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
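The cleared window spans `num_maskmem` memory slots on each side of the interacted frame, spaced by the temporal stride. A minimal sketch with example values (not asserted library defaults):

```python
# Frames whose non-conditioning memory would be dropped around an interacted frame.
frame_idx = 50
r, num_maskmem = 1, 7  # example values for the temporal stride and number of mask memories

begin = frame_idx - r * num_maskmem
end = frame_idx + r * num_maskmem
cleared = list(range(begin, end + 1))
print(cleared[0], cleared[-1], len(cleared))  # 43 57 15
```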
@@ -44,7 +44,7 @@ class BaseTrack:
         start_frame (int): The frame number where tracking started.
         frame_id (int): The most recent frame ID processed by the track.
         time_since_update (int): Frames passed since the last update.
-        location (Tuple): The location of the object in the context of multi-camera tracking.
+        location (tuple): The location of the object in the context of multi-camera tracking.
 
     Methods:
         end_frame: Returns the ID of the last frame where the object was tracked.
@@ -27,10 +27,9 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
         use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
 
     Returns:
-        (tuple): A tuple containing:
-            - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
-            - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
-            - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
+        matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
+        unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
+        unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
 
     Examples:
         >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -552,9 +552,8 @@ class v8PoseLoss(v8DetectionLoss):
             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
 
         Returns:
-            (tuple): Returns a tuple containing:
-                - kpts_loss (torch.Tensor): The keypoints loss.
-                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
@@ -549,19 +549,18 @@ def ap_per_class(
         prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
 
     Returns:
-        (tuple): A tuple of six arrays and one array of unique classes, where:
-            tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
-            fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
-            p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
-            r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
-            f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
-            ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
-            unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
-            p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
-            r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
-            f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
-            x (np.ndarray): X-axis values for the curves. Shape: (1000,).
-            prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
+        tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
+        fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
+        r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
+        f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
+        ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
+        unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
+        p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
+        r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
+        f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
+        x (np.ndarray): X-axis values for the curves. Shape: (1000,).
+        prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
     """
     # Sort by objectness
     i = np.argsort(-conf)
@@ -317,11 +317,11 @@ def clip_boxes(boxes, shape):
     Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
 
     Args:
-        boxes (torch.Tensor): the bounding boxes to clip
-        shape (tuple): the shape of the image
+        boxes (torch.Tensor): The bounding boxes to clip.
+        shape (tuple): The shape of the image.
 
     Returns:
-        (torch.Tensor | numpy.ndarray): Clipped boxes
+        (torch.Tensor | numpy.ndarray): The clipped boxes.
     """
     if isinstance(boxes, torch.Tensor):  # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
         boxes[..., 0] = boxes[..., 0].clamp(0, shape[1])  # x1
@@ -359,9 +359,9 @@ def scale_image(masks, im0_shape, ratio_pad=None):
     Takes a mask, and resizes it to the original image size.
 
     Args:
-        masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
-        im0_shape (tuple): the original image shape
-        ratio_pad (tuple): the ratio of the padding to the original image.
+        masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
+        im0_shape (tuple): The original image shape.
+        ratio_pad (tuple): The ratio of the padding to the original image.
 
     Returns:
         masks (np.ndarray): The masks that are being returned with shape [h, w, num].
@@ -692,12 +692,12 @@ def process_mask_native(protos, masks_in, bboxes, shape):
 
     Args:
         protos (torch.Tensor): [mask_dim, mask_h, mask_w]
-        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
-        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
-        shape (tuple): the size of the input image (h,w)
+        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
+        bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
+        shape (tuple): The size of the input image (h,w).
 
     Returns:
-        masks (torch.Tensor): The returned masks with dimensions [h, w, n]
+        masks (torch.Tensor): The returned masks with dimensions [h, w, n].
     """
     c, mh, mw = protos.shape  # CHW
     masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
@@ -584,8 +584,8 @@ class Annotator:
         Displays queue counts on an image centered at the points with customizable font size and colors.
 
         Args:
-            label (str): queue counts label
-            points (tuple): region points for center point calculation to display text
+            label (str): Queue counts label.
+            points (tuple): Region points for center point calculation to display text.
             region_color (tuple): RGB queue region color.
             txt_color (tuple): RGB text display color.
         """
@@ -624,13 +624,13 @@ class Annotator:
         Display the bounding boxes labels in parking management app.
 
         Args:
-            im0 (ndarray): inference image
-            text (str): object/class name
-            txt_color (tuple): display color for text foreground
-            bg_color (tuple): display color for text background
-            x_center (float): x position center point for bounding box
-            y_center (float): y position center point for bounding box
-            margin (int): gap between text and rectangle for better display
+            im0 (ndarray): Inference image.
+            text (str): Object/class name.
+            txt_color (tuple): Display color for text foreground.
+            bg_color (tuple): Display color for text background.
+            x_center (float): The x position center point for bounding box.
+            y_center (float): The y position center point for bounding box.
+            margin (int): The gap between text and rectangle for better display.
         """
         text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
         text_x = x_center - text_size[0] // 2
@@ -648,11 +648,11 @@ class Annotator:
         Display the overall statistics for parking lots.
 
         Args:
-            im0 (ndarray): inference image
-            text (dict): labels dictionary
-            txt_color (tuple): display color for text foreground
-            bg_color (tuple): display color for text background
-            margin (int): gap between text and rectangle for better display
+            im0 (ndarray): Inference image.
+            text (dict): Labels dictionary.
+            txt_color (tuple): Display color for text foreground.
+            bg_color (tuple): Display color for text background.
+            margin (int): Gap between text and rectangle for better display.
         """
         horizontal_gap = int(im0.shape[1] * 0.02)
         vertical_gap = int(im0.shape[0] * 0.01)