ultralytics 8.3.38 SAM 2 video inference (#14851)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
407815cf9e
commit
dcc9bd536f
16 changed files with 917 additions and 124 deletions
|
|
@ -8,6 +8,8 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
|
|||
segmentation tasks.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -16,7 +18,7 @@ from ultralytics.data.augment import LetterBox
|
|||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
from .amg import (
|
||||
batch_iterator,
|
||||
|
|
@ -95,7 +97,7 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task="segment", mode="predict"))
|
||||
overrides.update(dict(task="segment", mode="predict", batch=1))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.retina_masks = True
|
||||
self.im = None
|
||||
|
|
@ -114,7 +116,7 @@ class Predictor(BasePredictor):
|
|||
im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
|
||||
im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -181,10 +183,9 @@ class Predictor(BasePredictor):
|
|||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
(tuple): Contains the following three elements:
|
||||
- np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
(np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
|
||||
(np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -222,10 +223,8 @@ class Predictor(BasePredictor):
|
|||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores predicted by the model for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
(np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): Quality scores predicted by the model for each mask, with length C.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -329,10 +328,9 @@ class Predictor(BasePredictor):
|
|||
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
|
||||
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
|
||||
- pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
|
||||
pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
|
||||
pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
|
||||
pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -408,7 +406,7 @@ class Predictor(BasePredictor):
|
|||
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
def setup_model(self, model=None, verbose=True):
|
||||
"""
|
||||
Initializes the Segment Anything Model (SAM) for inference.
|
||||
|
||||
|
|
@ -416,7 +414,7 @@ class Predictor(BasePredictor):
|
|||
parameters for image normalization and other Ultralytics compatibility settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
|
||||
model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
|
||||
verbose (bool): If True, prints selected device information.
|
||||
|
||||
Examples:
|
||||
|
|
@ -459,7 +457,7 @@ class Predictor(BasePredictor):
|
|||
orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(List[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
metadata for each processed image.
|
||||
|
||||
Examples:
|
||||
|
|
@ -586,9 +584,8 @@ class Predictor(BasePredictor):
|
|||
nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
|
||||
|
||||
Returns:
|
||||
(tuple):
|
||||
- new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
|
||||
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
|
||||
new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
|
||||
keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
|
||||
|
||||
Examples:
|
||||
>>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
|
||||
|
|
@ -690,10 +687,8 @@ class SAM2Predictor(Predictor):
|
|||
img_idx (int): Index of the image in the batch to process.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
|
||||
(np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): Quality scores for each mask, with length C.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(cfg)
|
||||
|
|
@ -712,7 +707,7 @@ class SAM2Predictor(Predictor):
|
|||
"""
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
points = (points, labels) if points is not None else None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||
|
|
@ -751,7 +746,7 @@ class SAM2Predictor(Predictor):
|
|||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
|
||||
(tuple): A tuple containing transformed points, labels, and masks.
|
||||
"""
|
||||
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
|
||||
if bboxes is not None:
|
||||
|
|
@ -764,7 +759,7 @@ class SAM2Predictor(Predictor):
|
|||
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||
else:
|
||||
points, labels = bboxes, bbox_labels
|
||||
return bboxes, points, labels, masks
|
||||
return points, labels, masks
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
|
|
@ -815,3 +810,797 @@ class SAM2Predictor(Predictor):
|
|||
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
||||
][::-1]
|
||||
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
||||
|
||||
|
||||
class SAM2VideoPredictor(SAM2Predictor):
|
||||
"""
|
||||
SAM2VideoPredictor to handle user interactions with videos and manage inference states.
|
||||
|
||||
This class extends the functionality of SAM2Predictor to support video processing and maintains
|
||||
the state of inference operations. It includes configurations for managing non-overlapping masks,
|
||||
clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
|
||||
|
||||
Attributes:
|
||||
inference_state (Dict): A dictionary to store the current state of inference operations.
|
||||
non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
|
||||
clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
|
||||
clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
|
||||
callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
|
||||
|
||||
Args:
|
||||
cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
|
||||
overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
|
||||
_callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
|
||||
|
||||
Note:
|
||||
The `fill_hole_area` attribute is defined but not used in the current implementation.
|
||||
"""
|
||||
|
||||
# fill_hole_area = 8 # not used
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the predictor with configuration and optional overrides.
|
||||
|
||||
This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
|
||||
specified overrides, and sets up the inference state along with certain flags
|
||||
that control the behavior of the predictor.
|
||||
|
||||
Args:
|
||||
cfg (Dict): Configuration dictionary containing default settings.
|
||||
overrides (Dict | None): Dictionary of values to override default configuration.
|
||||
_callbacks (Dict | None): Dictionary of callback functions to customize behavior.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
|
||||
>>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
|
||||
>>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.inference_state = {}
|
||||
self.non_overlap_masks = True
|
||||
self.clear_non_cond_mem_around_input = False
|
||||
self.clear_non_cond_mem_for_multi_obj = False
|
||||
self.callbacks["on_predict_start"].append(self.init_state)
|
||||
|
||||
def get_model(self):
|
||||
"""
|
||||
Retrieves and configures the model with binarization enabled.
|
||||
|
||||
Note:
|
||||
This method overrides the base class implementation to set the binarize flag to True.
|
||||
"""
|
||||
model = super().get_model()
|
||||
model.set_binarize(True)
|
||||
return model
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
|
||||
"""
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
(np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
|
||||
"""
|
||||
# Override prompts if any stored in self.prompts
|
||||
bboxes = self.prompts.pop("bboxes", bboxes)
|
||||
points = self.prompts.pop("points", points)
|
||||
masks = self.prompts.pop("masks", masks)
|
||||
|
||||
frame = self.dataset.frame
|
||||
self.inference_state["im"] = im
|
||||
output_dict = self.inference_state["output_dict"]
|
||||
if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
|
||||
points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
if points is not None:
|
||||
for i in range(len(points)):
|
||||
self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
|
||||
elif masks is not None:
|
||||
for i in range(len(masks)):
|
||||
self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
|
||||
self.propagate_in_video_preflight()
|
||||
|
||||
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
raise RuntimeError("No points are provided; please add points first")
|
||||
|
||||
if frame in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame]
|
||||
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(frame)
|
||||
elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame]
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = self._run_single_frame_inference(
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame,
|
||||
batch_size=batch_size,
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=False,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
output_dict[storage_key][frame] = current_out
|
||||
# Create slices of per-object outputs for subsequent interaction with each
|
||||
# individual object after tracking.
|
||||
self._add_output_per_object(frame, current_out, storage_key)
|
||||
self.inference_state["frames_already_tracked"].append(frame)
|
||||
pred_masks = current_out["pred_masks"].flatten(0, 1)
|
||||
pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
|
||||
|
||||
return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-processes the predictions to apply non-overlapping constraints if required.
|
||||
|
||||
This method extends the post-processing functionality by applying non-overlapping constraints
|
||||
to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
|
||||
the masks do not overlap, which can be useful for certain applications.
|
||||
|
||||
Args:
|
||||
preds (Tuple[torch.Tensor]): The predictions from the model.
|
||||
img (torch.Tensor): The processed image tensor.
|
||||
orig_imgs (List[np.ndarray]): The original images before processing.
|
||||
|
||||
Returns:
|
||||
results (list): The post-processed predictions.
|
||||
|
||||
Note:
|
||||
If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
|
||||
"""
|
||||
results = super().postprocess(preds, img, orig_imgs)
|
||||
if self.non_overlap_masks:
|
||||
for result in results:
|
||||
if result.masks is None or len(result.masks) == 0:
|
||||
continue
|
||||
result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
|
||||
return results
|
||||
|
||||
@smart_inference_mode()
|
||||
def add_new_prompts(
|
||||
self,
|
||||
obj_id,
|
||||
points=None,
|
||||
labels=None,
|
||||
masks=None,
|
||||
frame_idx=0,
|
||||
):
|
||||
"""
|
||||
Adds new points or masks to a specific frame for a given object ID.
|
||||
|
||||
This method updates the inference state with new prompts (points or masks) for a specified
|
||||
object and frame index. It ensures that the prompts are either points or masks, but not both,
|
||||
and updates the internal state accordingly. It also handles the generation of new segmentations
|
||||
based on the provided prompts and the existing state.
|
||||
|
||||
Args:
|
||||
obj_id (int): The ID of the object to which the prompts are associated.
|
||||
points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
|
||||
labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
|
||||
masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
|
||||
frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
|
||||
|
||||
Raises:
|
||||
AssertionError: If both `masks` and `points` are provided, or neither is provided.
|
||||
|
||||
Note:
|
||||
- Only one type of prompt (either points or masks) can be added per call.
|
||||
- If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
|
||||
- The method handles the consolidation of outputs and resizing of masks to the original video resolution.
|
||||
"""
|
||||
assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
|
||||
obj_idx = self._obj_id_to_idx(obj_id)
|
||||
|
||||
point_inputs = None
|
||||
pop_key = "point_inputs_per_obj"
|
||||
if points is not None:
|
||||
point_inputs = {"point_coords": points, "point_labels": labels}
|
||||
self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
|
||||
pop_key = "mask_inputs_per_obj"
|
||||
self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
|
||||
self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
|
||||
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
|
||||
obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
||||
is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
|
||||
# Get any previously predicted mask logits on this object and feed it along with
|
||||
# the new clicks into the SAM mask decoder.
|
||||
prev_sam_mask_logits = None
|
||||
# lookup temporary output dict first, which contains the most recent output
|
||||
# (if not found, then lookup conditioning and non-conditioning frame output)
|
||||
if point_inputs is not None:
|
||||
prev_out = (
|
||||
obj_temp_output_dict[storage_key].get(frame_idx)
|
||||
or obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
||||
or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
)
|
||||
|
||||
if prev_out is not None and prev_out.get("pred_masks") is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits.clamp_(-32.0, 32.0)
|
||||
current_out = self._run_single_frame_inference(
|
||||
output_dict=obj_output_dict, # run on the slice of a single object
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=masks,
|
||||
reverse=False,
|
||||
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
||||
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
||||
# allows us to enforce non-overlapping constraints on all objects before encoding
|
||||
# them into memory.
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||
)
|
||||
# Add the output to the output dict (to be used as future memory)
|
||||
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
# Resize the output mask to the original video resolution
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
)
|
||||
pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
|
||||
return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
|
||||
|
||||
@smart_inference_mode()
|
||||
def propagate_in_video_preflight(self):
|
||||
"""
|
||||
Prepare inference_state and consolidate temporary outputs before tracking.
|
||||
|
||||
This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
|
||||
It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
|
||||
Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
|
||||
with the provided inputs.
|
||||
"""
|
||||
# Tracking has started and we don't allow adding new objects until session is reset.
|
||||
self.inference_state["tracking_has_started"] = True
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
|
||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||
# add them into "output_dict".
|
||||
temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
|
||||
output_dict = self.inference_state["output_dict"]
|
||||
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
||||
# temporary outputs have been added (either in this call or any previous calls
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||
for is_cond in {False, True}:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temprary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
)
|
||||
# merge them into "output_dict" and also create per-object slices
|
||||
output_dict[storage_key][frame_idx] = consolidated_out
|
||||
self._add_output_per_object(frame_idx, consolidated_out, storage_key)
|
||||
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(frame_idx)
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
obj_temp_output_dict[storage_key].clear()
|
||||
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in output_dict["cond_frame_outputs"]:
|
||||
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
|
||||
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
assert frame_idx in output_dict["cond_frame_outputs"]
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
|
||||
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
||||
# with either points or mask inputs (which should be true under a correct workflow).
|
||||
all_consolidated_frame_inds = (
|
||||
consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
|
||||
)
|
||||
input_frames_inds = set()
|
||||
for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
|
||||
input_frames_inds.update(point_inputs_per_frame.keys())
|
||||
for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
|
||||
input_frames_inds.update(mask_inputs_per_frame.keys())
|
||||
assert all_consolidated_frame_inds == input_frames_inds
|
||||
|
||||
@staticmethod
|
||||
def init_state(predictor):
|
||||
"""
|
||||
Initialize an inference state for the predictor.
|
||||
|
||||
This function sets up the initial state required for performing inference on video data.
|
||||
It includes initializing various dictionaries and ordered dictionaries that will store
|
||||
inputs, outputs, and other metadata relevant to the tracking process.
|
||||
|
||||
Args:
|
||||
predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
|
||||
"""
|
||||
if len(predictor.inference_state) > 0: # means initialized
|
||||
return
|
||||
assert predictor.dataset is not None
|
||||
assert predictor.dataset.mode == "video"
|
||||
|
||||
inference_state = {}
|
||||
inference_state["num_frames"] = predictor.dataset.frames
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
# values that don't change across frames (so we only need to hold one copy of them)
|
||||
inference_state["constants"] = {}
|
||||
# mapping between client-side object id and model-side object index
|
||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||
inference_state["obj_ids"] = []
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
inference_state["output_dict"] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
inference_state["output_dict_per_obj"] = {}
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
||||
inference_state["temp_output_dict_per_obj"] = {}
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
inference_state["consolidated_frame_inds"] = {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
}
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = []
|
||||
predictor.inference_state = inference_state
|
||||
|
||||
def get_im_features(self, im, batch=1):
|
||||
"""
|
||||
Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The input image tensor.
|
||||
batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
vis_feats (torch.Tensor): The visual features extracted from the image.
|
||||
vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
|
||||
feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
|
||||
|
||||
Note:
|
||||
- If `batch` is greater than 1, the features are expanded to fit the batch size.
|
||||
- The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
|
||||
"""
|
||||
backbone_out = self.model.forward_image(im)
|
||||
if batch > 1: # expand features if there's more than one prompt
|
||||
for i, feat in enumerate(backbone_out["backbone_fpn"]):
|
||||
backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
|
||||
for i, pos in enumerate(backbone_out["vision_pos_enc"]):
|
||||
pos = pos.expand(batch, -1, -1, -1)
|
||||
backbone_out["vision_pos_enc"][i] = pos
|
||||
_, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
|
||||
return vis_feats, vis_pos_embed, feat_sizes
|
||||
|
||||
def _obj_id_to_idx(self, obj_id):
|
||||
"""
|
||||
Map client-side object id to model-side object index.
|
||||
|
||||
Args:
|
||||
obj_id (int): The unique identifier of the object provided by the client side.
|
||||
|
||||
Returns:
|
||||
obj_idx (int): The index of the object on the model side.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an attempt is made to add a new object after tracking has started.
|
||||
|
||||
Note:
|
||||
- The method updates or retrieves mappings between object IDs and indices stored in
|
||||
`inference_state`.
|
||||
- It ensures that new objects can only be added before tracking commences.
|
||||
- It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
|
||||
- Additional data structures are initialized for the new object to store inputs and outputs.
|
||||
"""
|
||||
obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||
if obj_idx is not None:
|
||||
return obj_idx
|
||||
|
||||
# This is a new object id not sent to the server before. We only allow adding
|
||||
# new objects *before* the tracking starts.
|
||||
allow_new_object = not self.inference_state["tracking_has_started"]
|
||||
if allow_new_object:
|
||||
# get the next object slot
|
||||
obj_idx = len(self.inference_state["obj_id_to_idx"])
|
||||
self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
||||
self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
||||
self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
|
||||
# set up input and output structures for this object
|
||||
self.inference_state["point_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
return obj_idx
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Cannot add new object id {obj_id} after tracking starts. "
|
||||
f"All existing object ids: {self.inference_state['obj_ids']}. "
|
||||
f"Please call 'reset_state' to restart from scratch."
|
||||
)
|
||||
|
||||
def _run_single_frame_inference(
|
||||
self,
|
||||
output_dict,
|
||||
frame_idx,
|
||||
batch_size,
|
||||
is_init_cond_frame,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
reverse,
|
||||
run_mem_encoder,
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
"""
|
||||
Run tracking on a single frame based on current inputs and previous memory.
|
||||
|
||||
Args:
|
||||
output_dict (Dict): The dictionary containing the output states of the tracking process.
|
||||
frame_idx (int): The index of the current frame.
|
||||
batch_size (int): The batch size for processing the frame.
|
||||
is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
|
||||
point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
|
||||
mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
|
||||
reverse (bool): Indicates if the tracking should be performed in reverse order.
|
||||
run_mem_encoder (bool): Indicates if the memory encoder should be executed.
|
||||
prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
|
||||
|
||||
Returns:
|
||||
current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
|
||||
|
||||
Raises:
|
||||
AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
|
||||
|
||||
Note:
|
||||
- The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
|
||||
- The method retrieves image features using the `get_im_features` method.
|
||||
- The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
|
||||
- The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
|
||||
self.inference_state["im"], batch_size
|
||||
)
|
||||
|
||||
# point and mask should not appear as input simultaneously on the same frame
|
||||
assert point_inputs is None or mask_inputs is None
|
||||
current_out = self.model.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict=output_dict,
|
||||
num_frames=self.inference_state["num_frames"],
|
||||
track_in_reverse=reverse,
|
||||
run_mem_encoder=run_mem_encoder,
|
||||
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
if maskmem_features is not None:
|
||||
current_out["maskmem_features"] = maskmem_features.to(
|
||||
dtype=torch.float16, device=self.device, non_blocking=True
|
||||
)
|
||||
# NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
|
||||
# potentially fill holes in the predicted masks
|
||||
# if self.fill_hole_area > 0:
|
||||
# pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
|
||||
# pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
|
||||
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
|
||||
return current_out
|
||||
|
||||
def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
|
||||
"""
|
||||
Caches and manages the positional encoding for mask memory across frames and objects.
|
||||
|
||||
This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
|
||||
mask memory, which is constant across frames and objects, thus reducing the amount of
|
||||
redundant information stored during an inference session. It checks if the positional
|
||||
encoding has already been cached; if not, it caches a slice of the provided encoding.
|
||||
If the batch size is greater than one, it expands the cached positional encoding to match
|
||||
the current batch size.
|
||||
|
||||
Args:
|
||||
out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
|
||||
Should be a list of tensors or None.
|
||||
|
||||
Returns:
|
||||
out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
|
||||
|
||||
Note:
|
||||
- The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
|
||||
- Only a single object's slice is cached since the encoding is the same across objects.
|
||||
- The method checks if the positional encoding has already been cached in the session's constants.
|
||||
- If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
|
||||
"""
|
||||
model_constants = self.inference_state["constants"]
|
||||
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
||||
if out_maskmem_pos_enc is not None:
|
||||
if "maskmem_pos_enc" not in model_constants:
|
||||
assert isinstance(out_maskmem_pos_enc, list)
|
||||
# only take the slice for one object, since it's same across objects
|
||||
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
||||
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
||||
# expand the cached maskmem_pos_enc to the actual batch size
|
||||
batch_size = out_maskmem_pos_enc[0].size(0)
|
||||
if batch_size > 1:
|
||||
out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
|
||||
return out_maskmem_pos_enc
|
||||
|
||||
def _consolidate_temp_output_across_obj(
|
||||
self,
|
||||
frame_idx,
|
||||
is_cond=False,
|
||||
run_mem_encoder=False,
|
||||
):
|
||||
"""
|
||||
Consolidates per-object temporary outputs into a single output for all objects.
|
||||
|
||||
This method combines the temporary outputs for each object on a given frame into a unified
|
||||
output. It fills in any missing objects either from the main output dictionary or leaves
|
||||
placeholders if they do not exist in the main output. Optionally, it can re-run the memory
|
||||
encoder after applying non-overlapping constraints to the object scores.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the frame for which to consolidate outputs.
|
||||
is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
|
||||
Defaults to False.
|
||||
run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
|
||||
consolidating the outputs. Defaults to False.
|
||||
|
||||
Returns:
|
||||
consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
|
||||
|
||||
Note:
|
||||
- The method initializes the consolidated output with placeholder values for missing objects.
|
||||
- It searches for outputs in both the temporary and main output dictionaries.
|
||||
- If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
|
||||
- The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
|
||||
"""
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
|
||||
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
||||
# will be added when rerunning the memory encoder after applying non-overlapping
|
||||
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
||||
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
||||
consolidated_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": torch.full(
|
||||
size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
|
||||
fill_value=-1024.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
"obj_ptr": torch.full(
|
||||
size=(batch_size, self.model.hidden_dim),
|
||||
fill_value=-1024.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
"object_score_logits": torch.full(
|
||||
size=(batch_size, 1),
|
||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
||||
fill_value=10.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
for obj_idx in range(batch_size):
|
||||
obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
|
||||
out = (
|
||||
obj_temp_output_dict[storage_key].get(frame_idx)
|
||||
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
||||
# we fall back and look up its previous output in "output_dict_per_obj".
|
||||
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
||||
# "output_dict_per_obj" to find a previous output for this object.
|
||||
or obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
||||
or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
)
|
||||
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
||||
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
||||
# placeholder above) and set its object pointer to be a dummy pointer.
|
||||
if out is None:
|
||||
# Fill in dummy object pointers for those objects without any inputs or
|
||||
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
||||
# i.e. when we need to build the memory for tracking).
|
||||
if run_mem_encoder:
|
||||
# fill object pointer with a dummy pointer (based on an empty mask)
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
|
||||
continue
|
||||
# Add the temporary object output mask to consolidated output mask
|
||||
consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
||||
|
||||
# Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
|
||||
if run_mem_encoder:
|
||||
high_res_masks = F.interpolate(
|
||||
consolidated_out["pred_masks"],
|
||||
size=self.imgsz,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
if self.model.non_overlap_masks_for_mem_enc:
|
||||
high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
|
||||
consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
|
||||
batch_size=batch_size,
|
||||
high_res_masks=high_res_masks,
|
||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
||||
object_score_logits=consolidated_out["object_score_logits"],
|
||||
)
|
||||
|
||||
return consolidated_out
|
||||
|
||||
def _get_empty_mask_ptr(self, frame_idx):
|
||||
"""
|
||||
Get a dummy object pointer based on an empty mask on the current frame.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
|
||||
|
||||
# Feed the empty mask and image feature above to get a dummy object pointer
|
||||
current_out = self.model.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=True,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=None,
|
||||
# A dummy (empty) mask with a single object
|
||||
mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
|
||||
output_dict={},
|
||||
num_frames=self.inference_state["num_frames"],
|
||||
track_in_reverse=False,
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=None,
|
||||
)
|
||||
return current_out["obj_ptr"]
|
||||
|
||||
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
|
||||
"""
|
||||
Run the memory encoder on masks.
|
||||
|
||||
This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
|
||||
memory also needs to be computed again with the memory encoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): The batch size for processing the frame.
|
||||
high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
|
||||
object_score_logits (torch.Tensor): Logits representing the object scores.
|
||||
is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
|
||||
|
||||
Returns:
|
||||
(tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
|
||||
maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks,
|
||||
is_mask_from_pts=is_mask_from_pts,
|
||||
object_score_logits=object_score_logits,
|
||||
)
|
||||
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
|
||||
return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
|
||||
|
||||
def _add_output_per_object(self, frame_idx, current_out, storage_key):
|
||||
"""
|
||||
Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
|
||||
|
||||
The resulting slices share the same tensor storage.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame.
|
||||
current_out (Dict): The current output dictionary containing multi-object outputs.
|
||||
storage_key (str): The key used to store the output in the per-object output dictionary.
|
||||
"""
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
||||
|
||||
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
||||
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
||||
|
||||
for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
|
||||
obj_slice = slice(obj_idx, obj_idx + 1)
|
||||
obj_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
||||
}
|
||||
if maskmem_features is not None:
|
||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
||||
if maskmem_pos_enc is not None:
|
||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
def _clear_non_cond_mem_around_input(self, frame_idx):
|
||||
"""
|
||||
Remove the non-conditioning memory around the input frame.
|
||||
|
||||
When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
|
||||
object appearance information and could confuse the model. This method clears those non-conditioning memories
|
||||
surrounding the interacted frame to avoid giving the model both old and new information about the object.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame where user interaction occurred.
|
||||
"""
|
||||
r = self.model.memory_temporal_stride_for_eval
|
||||
frame_idx_begin = frame_idx - r * self.model.num_maskmem
|
||||
frame_idx_end = frame_idx + r * self.model.num_maskmem
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
|
||||
for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue