diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md
index 86059422..983d8cdc 100644
--- a/docs/en/models/sam-2.md
+++ b/docs/en/models/sam-2.md
@@ -194,6 +194,34 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide
yolo predict model=sam2.1_b.pt source=path/to/video.mp4
```
+#### Segment Video and Track objects
+
+!!! example "Segment Video"
+
+ Segment the entire video content with specific prompts and track objects.
+
+ === "Python"
+
+ ```python
+ from ultralytics.models.sam import SAM2VideoPredictor
+
+ # Create SAM2VideoPredictor
+ overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
+ predictor = SAM2VideoPredictor(overrides=overrides)
+
+ # Run inference with single point
+ results = predictor(source="test.mp4", points=[920, 470], labels=1)
+
+ # Run inference with multiple points
+ results = predictor(source="test.mp4", points=[[920, 470], [909, 138]], labels=[1, 1])
+
+ # Run inference with multiple points prompt per object
+ results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 1]])
+
+ # Run inference with negative points prompt
+ results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 0]])
+ ```
+
- This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts (bboxes/points/masks) are provided.
## SAM 2 comparison vs YOLOv8
diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md
index e715225c..17f8b472 100644
--- a/docs/en/reference/models/sam/predict.md
+++ b/docs/en/reference/models/sam/predict.md
@@ -17,4 +17,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode
## ::: ultralytics.models.sam.predict.SAM2Predictor
+
+
+## ::: ultralytics.models.sam.predict.SAM2VideoPredictor
+
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index e17c6c7d..fe22ab07 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.3.37"
+__version__ = "8.3.38"
import os
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index 9bfd6a60..de9ef96a 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -740,9 +740,8 @@ def parse_key_value_pair(pair: str = "key=value"):
pair (str): A string containing a key-value pair in the format "key=value".
Returns:
- (tuple): A tuple containing two elements:
- - key (str): The parsed key.
- - value (str): The parsed value.
+ key (str): The parsed key.
+ value (str): The parsed value.
Raises:
AssertionError: If the value is missing or empty.
diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py
index bd821de2..5ec011d8 100644
--- a/ultralytics/data/augment.py
+++ b/ultralytics/data/augment.py
@@ -2111,10 +2111,9 @@ class Format:
h (int): Height of the image.
Returns:
- (tuple): Tuple containing:
- masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
- instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
- cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
+ masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
+ instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
+ cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
Notes:
- If self.mask_overlap is True, masks are overlapped and sorted by area.
diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py
index ead7d613..ae5677cc 100644
--- a/ultralytics/data/loaders.py
+++ b/ultralytics/data/loaders.py
@@ -354,7 +354,7 @@ class LoadImagesAndVideos:
self.nf = ni + nv # number of files
self.ni = ni # number of images
self.video_flag = [False] * ni + [True] * nv
- self.mode = "image"
+ self.mode = "video" if ni == 0 else "image" # default to video if no images
self.vid_stride = vid_stride # video frame-rate stride
self.bs = batch
if any(videos):
diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py
index a29f5cb3..30e34236 100644
--- a/ultralytics/models/sam/__init__.py
+++ b/ultralytics/models/sam/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from .model import SAM
-from .predict import Predictor, SAM2Predictor
+from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
-__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list
+__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index e685dc4e..97349a66 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -148,7 +148,7 @@ class SAM(Model):
verbose (bool): If True, prints the information to the console.
Returns:
- (Tuple): A tuple containing the model's information (string representations of the model).
+ (tuple): A tuple containing the model's information (string representations of the model).
Examples:
>>> sam = SAM("sam_b.pt")
diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py
index 7bfd7166..5d48ed1f 100644
--- a/ultralytics/models/sam/modules/sam.py
+++ b/ultralytics/models/sam/modules/sam.py
@@ -36,8 +36,6 @@ class SAMModel(nn.Module):
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
- pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
- pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
Methods:
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
@@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module):
self.sam_prompt_embed_dim = self.hidden_dim
self.sam_image_embedding_size = self.image_size // self.backbone_stride
- # build PromptEncoder and MaskDecoder from SAM
- # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+ # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
self.sam_prompt_encoder = PromptEncoder(
embed_dim=self.sam_prompt_embed_dim,
image_embedding_size=(
@@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module):
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
- low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
- high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
+ low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
+ high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
object_score_logits: Tensor of shape (B,) with object score logits.
@@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module):
boxes=None,
masks=sam_mask_prompt,
)
- (
- low_res_multimasks,
- ious,
- sam_output_tokens,
- object_score_logits,
- ) = self.sam_mask_decoder(
+ low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
@@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module):
if self.pred_obj_scores:
is_obj_appearing = object_score_logits > 0
- # Mask used for spatial memories is always a *hard* choice between obj and no obj,
- # consistent with the actual mask prediction
- low_res_multimasks = torch.where(
- is_obj_appearing[:, None, None],
- low_res_multimasks,
- NO_OBJ_SCORE,
- )
+ # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
+ low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
# convert masks from possibly bfloat16 (or float16) to float32
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
@@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module):
def _prepare_backbone_features(self, backbone_out):
"""Prepares and flattens visual features from the image backbone output for further processing."""
- backbone_out = backbone_out.copy()
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
@@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module):
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
if self.sigmoid_bias_for_mem_enc != 0.0:
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
- maskmem_out = self.memory_encoder(
- pix_feat,
- mask_for_mem,
- skip_mask_sigmoid=True, # sigmoid already applied
- )
+ maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
maskmem_features = maskmem_out["vision_features"]
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
# add a no-object embedding to the spatial memory to indicate that the frame
@@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module):
track_in_reverse,
prev_sam_mask_logits,
)
-
- (
- _,
- _,
- _,
- low_res_masks,
- high_res_masks,
- obj_ptr,
- object_score_logits,
- ) = sam_outputs
+ _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
current_out["pred_masks"] = low_res_masks
current_out["pred_masks_high_res"] = high_res_masks
@@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module):
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
current_out["object_score_logits"] = object_score_logits
- # Finally run the memory encoder on the predicted mask to encode
- # it into a new memory feature (that can be used in future frames)
+ # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
self._encode_memory_in_output(
current_vision_feats,
feat_sizes,
@@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module):
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
)
- def _apply_non_overlapping_constraints(self, pred_masks):
- """Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
+ @staticmethod
+ def _apply_non_overlapping_constraints(pred_masks):
+ """Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
batch_size = pred_masks.size(0)
if batch_size == 1:
return pred_masks
@@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module):
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
return pred_masks
+ def set_binarize(self, binarize=False):
+ """Set binarize for VideoPredictor."""
+ self.binarize_mask_from_pts_for_mem_enc = binarize
+
def set_imgsz(self, imgsz):
"""
Set image size to make model compatible with different image sizes.
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index a8315908..540d1007 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -8,6 +8,8 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
segmentation tasks.
"""
+from collections import OrderedDict
+
import numpy as np
import torch
import torch.nn.functional as F
@@ -16,7 +18,7 @@ from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
-from ultralytics.utils.torch_utils import select_device
+from ultralytics.utils.torch_utils import select_device, smart_inference_mode
from .amg import (
batch_iterator,
@@ -95,7 +97,7 @@ class Predictor(BasePredictor):
"""
if overrides is None:
overrides = {}
- overrides.update(dict(task="segment", mode="predict"))
+ overrides.update(dict(task="segment", mode="predict", batch=1))
super().__init__(cfg, overrides, _callbacks)
self.args.retina_masks = True
self.im = None
@@ -114,7 +116,7 @@ class Predictor(BasePredictor):
im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
Returns:
- (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+ im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
Examples:
>>> predictor = Predictor()
@@ -181,10 +183,9 @@ class Predictor(BasePredictor):
**kwargs (Any): Additional keyword arguments.
Returns:
- (tuple): Contains the following three elements:
- - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
- - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+ (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
Examples:
>>> predictor = Predictor()
@@ -222,10 +223,8 @@ class Predictor(BasePredictor):
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
- (tuple): Tuple containing:
- - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
- - np.ndarray: Quality scores predicted by the model for each mask, with length C.
- - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
+ (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): Quality scores predicted by the model for each mask, with length C.
Examples:
>>> predictor = Predictor()
@@ -329,10 +328,9 @@ class Predictor(BasePredictor):
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
Returns:
- (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
- - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
- - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
- - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+ pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+ pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+ pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
Examples:
>>> predictor = Predictor()
@@ -408,7 +406,7 @@ class Predictor(BasePredictor):
return pred_masks, pred_scores, pred_bboxes
- def setup_model(self, model, verbose=True):
+ def setup_model(self, model=None, verbose=True):
"""
Initializes the Segment Anything Model (SAM) for inference.
@@ -416,7 +414,7 @@ class Predictor(BasePredictor):
parameters for image normalization and other Ultralytics compatibility settings.
Args:
- model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
+ model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
verbose (bool): If True, prints selected device information.
Examples:
@@ -459,7 +457,7 @@ class Predictor(BasePredictor):
orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
Returns:
- (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+ results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
metadata for each processed image.
Examples:
@@ -586,9 +584,8 @@ class Predictor(BasePredictor):
nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
Returns:
- (tuple):
- - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
- - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+ new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+ keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
Examples:
>>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
@@ -690,10 +687,8 @@ class SAM2Predictor(Predictor):
img_idx (int): Index of the image in the batch to process.
Returns:
- (tuple): Tuple containing:
- - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
- - np.ndarray: Quality scores for each mask, with length C.
- - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
+ (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
+ (np.ndarray): Quality scores for each mask, with length C.
Examples:
>>> predictor = SAM2Predictor(cfg)
@@ -712,7 +707,7 @@ class SAM2Predictor(Predictor):
"""
features = self.get_im_features(im) if self.features is None else self.features
- bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
points = (points, labels) if points is not None else None
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
@@ -751,7 +746,7 @@ class SAM2Predictor(Predictor):
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
Returns:
- (tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
+ (tuple): A tuple containing transformed points, labels, and masks.
"""
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
if bboxes is not None:
@@ -764,7 +759,7 @@ class SAM2Predictor(Predictor):
labels = torch.cat([bbox_labels, labels], dim=1)
else:
points, labels = bboxes, bbox_labels
- return bboxes, points, labels, masks
+ return points, labels, masks
def set_image(self, image):
"""
@@ -815,3 +810,797 @@ class SAM2Predictor(Predictor):
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
][::-1]
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+
+
+class SAM2VideoPredictor(SAM2Predictor):
+ """
+ SAM2VideoPredictor to handle user interactions with videos and manage inference states.
+
+ This class extends the functionality of SAM2Predictor to support video processing and maintains
+ the state of inference operations. It includes configurations for managing non-overlapping masks,
+ clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
+
+ Attributes:
+ inference_state (Dict): A dictionary to store the current state of inference operations.
+ non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
+ clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
+ clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
+ callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
+
+ Args:
+ cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
+ overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
+ _callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
+
+ Note:
+ The `fill_hole_area` attribute is defined but not used in the current implementation.
+ """
+
+ # fill_hole_area = 8 # not used
+
+ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+ """
+ Initialize the predictor with configuration and optional overrides.
+
+ This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
+ specified overrides, and sets up the inference state along with certain flags
+ that control the behavior of the predictor.
+
+ Args:
+ cfg (Dict): Configuration dictionary containing default settings.
+ overrides (Dict | None): Dictionary of values to override default configuration.
+ _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+ Examples:
+ >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
+ >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
+ >>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
+ """
+ super().__init__(cfg, overrides, _callbacks)
+ self.inference_state = {}
+ self.non_overlap_masks = True
+ self.clear_non_cond_mem_around_input = False
+ self.clear_non_cond_mem_for_multi_obj = False
+ self.callbacks["on_predict_start"].append(self.init_state)
+
+ def get_model(self):
+ """
+ Retrieves and configures the model with binarization enabled.
+
+ Note:
+ This method overrides the base class implementation to set the binarize flag to True.
+ """
+ model = super().get_model()
+ model.set_binarize(True)
+ return model
+
+ def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
+ """
+ Perform image segmentation inference based on the given input cues, using the currently loaded image. This
+ method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
+ mask decoder for real-time and promptable segmentation tasks.
+
+ Args:
+ im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
+ bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
+ points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
+ labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
+ masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
+
+ Returns:
+ (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
+ (np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
+ """
+ # Override prompts if any stored in self.prompts
+ bboxes = self.prompts.pop("bboxes", bboxes)
+ points = self.prompts.pop("points", points)
+ masks = self.prompts.pop("masks", masks)
+
+ frame = self.dataset.frame
+ self.inference_state["im"] = im
+ output_dict = self.inference_state["output_dict"]
+ if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
+ points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
+ if points is not None:
+ for i in range(len(points)):
+ self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
+ elif masks is not None:
+ for i in range(len(masks)):
+ self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
+ self.propagate_in_video_preflight()
+
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+ batch_size = len(self.inference_state["obj_idx_to_id"])
+ if len(output_dict["cond_frame_outputs"]) == 0:
+ raise RuntimeError("No points are provided; please add points first")
+
+ if frame in consolidated_frame_inds["cond_frame_outputs"]:
+ storage_key = "cond_frame_outputs"
+ current_out = output_dict[storage_key][frame]
+ if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(frame)
+ elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
+ storage_key = "non_cond_frame_outputs"
+ current_out = output_dict[storage_key][frame]
+ else:
+ storage_key = "non_cond_frame_outputs"
+ current_out = self._run_single_frame_inference(
+ output_dict=output_dict,
+ frame_idx=frame,
+ batch_size=batch_size,
+ is_init_cond_frame=False,
+ point_inputs=None,
+ mask_inputs=None,
+ reverse=False,
+ run_mem_encoder=True,
+ )
+ output_dict[storage_key][frame] = current_out
+ # Create slices of per-object outputs for subsequent interaction with each
+ # individual object after tracking.
+ self._add_output_per_object(frame, current_out, storage_key)
+ self.inference_state["frames_already_tracked"].append(frame)
+ pred_masks = current_out["pred_masks"].flatten(0, 1)
+ pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
+
+ return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
+
+ def postprocess(self, preds, img, orig_imgs):
+ """
+ Post-processes the predictions to apply non-overlapping constraints if required.
+
+ This method extends the post-processing functionality by applying non-overlapping constraints
+ to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
+ the masks do not overlap, which can be useful for certain applications.
+
+ Args:
+ preds (Tuple[torch.Tensor]): The predictions from the model.
+ img (torch.Tensor): The processed image tensor.
+ orig_imgs (List[np.ndarray]): The original images before processing.
+
+ Returns:
+ results (list): The post-processed predictions.
+
+ Note:
+ If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
+ """
+ results = super().postprocess(preds, img, orig_imgs)
+ if self.non_overlap_masks:
+ for result in results:
+ if result.masks is None or len(result.masks) == 0:
+ continue
+ result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
+ return results
+
+ @smart_inference_mode()
+ def add_new_prompts(
+ self,
+ obj_id,
+ points=None,
+ labels=None,
+ masks=None,
+ frame_idx=0,
+ ):
+ """
+ Adds new points or masks to a specific frame for a given object ID.
+
+ This method updates the inference state with new prompts (points or masks) for a specified
+ object and frame index. It ensures that the prompts are either points or masks, but not both,
+ and updates the internal state accordingly. It also handles the generation of new segmentations
+ based on the provided prompts and the existing state.
+
+ Args:
+ obj_id (int): The ID of the object to which the prompts are associated.
+ points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
+ labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
+ masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
+ frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
+
+ Returns:
+ (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
+
+ Raises:
+ AssertionError: If both `masks` and `points` are provided, or neither is provided.
+
+ Note:
+ - Only one type of prompt (either points or masks) can be added per call.
+ - If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
+ - The method handles the consolidation of outputs and resizing of masks to the original video resolution.
+ """
+ assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
+ obj_idx = self._obj_id_to_idx(obj_id)
+
+ point_inputs = None
+ pop_key = "point_inputs_per_obj"
+ if points is not None:
+ point_inputs = {"point_coords": points, "point_labels": labels}
+ self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
+ pop_key = "mask_inputs_per_obj"
+ self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
+ self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
+ # If this frame hasn't been tracked before, we treat it as an initial conditioning
+ # frame, meaning that the inputs points are to generate segments on this frame without
+ # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+ # the input points will be used to correct the already tracked masks.
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
+ obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+ obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+ # Add a frame to conditioning output if it's an initial conditioning frame or
+ # if the model sees all frames receiving clicks/mask as conditioning frames.
+ is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Get any previously predicted mask logits on this object and feed it along with
+ # the new clicks into the SAM mask decoder.
+ prev_sam_mask_logits = None
+ # lookup temporary output dict first, which contains the most recent output
+ # (if not found, then lookup conditioning and non-conditioning frame output)
+ if point_inputs is not None:
+ prev_out = (
+ obj_temp_output_dict[storage_key].get(frame_idx)
+ or obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+ )
+
+ if prev_out is not None and prev_out.get("pred_masks") is not None:
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
+ # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+ prev_sam_mask_logits.clamp_(-32.0, 32.0)
+ current_out = self._run_single_frame_inference(
+ output_dict=obj_output_dict, # run on the slice of a single object
+ frame_idx=frame_idx,
+ batch_size=1, # run on the slice of a single object
+ is_init_cond_frame=is_init_cond_frame,
+ point_inputs=point_inputs,
+ mask_inputs=masks,
+ reverse=False,
+ # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+ # at the beginning of `propagate_in_video` (after user finalize their clicks). This
+ # allows us to enforce non-overlapping constraints on all objects before encoding
+ # them into memory.
+ run_mem_encoder=False,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+ # Add the output to the output dict (to be used as future memory)
+ obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+ # Resize the output mask to the original video resolution
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ frame_idx,
+ is_cond=is_cond,
+ run_mem_encoder=False,
+ )
+ pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
+ return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
+
+ @smart_inference_mode()
+ def propagate_in_video_preflight(self):
+ """
+ Prepare inference_state and consolidate temporary outputs before tracking.
+
+ This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
+ It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
+ Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
+ with the provided inputs.
+ """
+ # Tracking has started and we don't allow adding new objects until session is reset.
+ self.inference_state["tracking_has_started"] = True
+ batch_size = len(self.inference_state["obj_idx_to_id"])
+
+ # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+ # add them into "output_dict".
+ temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
+ output_dict = self.inference_state["output_dict"]
+ # "consolidated_frame_inds" contains indices of those frames where consolidated
+ # temporary outputs have been added (either in this call or any previous calls
+ # to `propagate_in_video_preflight`).
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
+ for is_cond in {False, True}:
+ # Separately consolidate conditioning and non-conditioning temp outptus
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+ # Find all the frames that contain temporary outputs for any objects
+ # (these should be the frames that have just received clicks for mask inputs
+ # via `add_new_points` or `add_new_mask`)
+ temp_frame_inds = set()
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+ consolidated_frame_inds[storage_key].update(temp_frame_inds)
+ # consolidate the temprary output across all objects on this frame
+ for frame_idx in temp_frame_inds:
+ consolidated_out = self._consolidate_temp_output_across_obj(
+ frame_idx, is_cond=is_cond, run_mem_encoder=True
+ )
+ # merge them into "output_dict" and also create per-object slices
+ output_dict[storage_key][frame_idx] = consolidated_out
+ self._add_output_per_object(frame_idx, consolidated_out, storage_key)
+ if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
+ # clear non-conditioning memory of the surrounding frames
+ self._clear_non_cond_mem_around_input(frame_idx)
+
+ # clear temporary outputs in `temp_output_dict_per_obj`
+ for obj_temp_output_dict in temp_output_dict_per_obj.values():
+ obj_temp_output_dict[storage_key].clear()
+
+ # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+ # output on the same frame in "non_cond_frame_outputs"
+ for frame_idx in output_dict["cond_frame_outputs"]:
+ output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+ for frame_idx in obj_output_dict["cond_frame_outputs"]:
+ obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+ for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+ assert frame_idx in output_dict["cond_frame_outputs"]
+ consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+ # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+ # with either points or mask inputs (which should be true under a correct workflow).
+ all_consolidated_frame_inds = (
+ consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
+ )
+ input_frames_inds = set()
+ for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
+ input_frames_inds.update(point_inputs_per_frame.keys())
+ for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
+ input_frames_inds.update(mask_inputs_per_frame.keys())
+ assert all_consolidated_frame_inds == input_frames_inds
+
+ @staticmethod
+ def init_state(predictor):
+ """
+ Initialize an inference state for the predictor.
+
+ This function sets up the initial state required for performing inference on video data.
+ It includes initializing various dictionaries and ordered dictionaries that will store
+ inputs, outputs, and other metadata relevant to the tracking process.
+
+ Args:
+ predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
+ """
+ if len(predictor.inference_state) > 0: # means initialized
+ return
+ assert predictor.dataset is not None
+ assert predictor.dataset.mode == "video"
+
+ inference_state = {}
+ inference_state["num_frames"] = predictor.dataset.frames
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # A storage to hold the model's tracking results and states on each frame
+ inference_state["output_dict"] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+ # A temporary storage to hold new outputs when user interact with a frame
+ # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+ # Frames that already holds consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ inference_state["consolidated_frame_inds"] = {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ }
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"] = []
+ predictor.inference_state = inference_state
+
+ def get_im_features(self, im, batch=1):
+ """
+ Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
+
+ Args:
+ im (torch.Tensor): The input image tensor.
+ batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
+
+ Returns:
+ vis_feats (torch.Tensor): The visual features extracted from the image.
+ vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
+ feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
+
+ Note:
+ - If `batch` is greater than 1, the features are expanded to fit the batch size.
+ - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
+ """
+ backbone_out = self.model.forward_image(im)
+ if batch > 1: # expand features if there's more than one prompt
+ for i, feat in enumerate(backbone_out["backbone_fpn"]):
+ backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
+ for i, pos in enumerate(backbone_out["vision_pos_enc"]):
+ pos = pos.expand(batch, -1, -1, -1)
+ backbone_out["vision_pos_enc"][i] = pos
+ _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
+ return vis_feats, vis_pos_embed, feat_sizes
+
+ def _obj_id_to_idx(self, obj_id):
+ """
+ Map client-side object id to model-side object index.
+
+ Args:
+ obj_id (int): The unique identifier of the object provided by the client side.
+
+ Returns:
+ obj_idx (int): The index of the object on the model side.
+
+ Raises:
+ RuntimeError: If an attempt is made to add a new object after tracking has started.
+
+ Note:
+ - The method updates or retrieves mappings between object IDs and indices stored in
+ `inference_state`.
+ - It ensures that new objects can only be added before tracking commences.
+ - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
+ - Additional data structures are initialized for the new object to store inputs and outputs.
+ """
+ obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ # This is a new object id not sent to the server before. We only allow adding
+ # new objects *before* the tracking starts.
+ allow_new_object = not self.inference_state["tracking_has_started"]
+ if allow_new_object:
+ # get the next object slot
+ obj_idx = len(self.inference_state["obj_id_to_idx"])
+ self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
+ self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
+ self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
+ # set up input and output structures for this object
+ self.inference_state["point_inputs_per_obj"][obj_idx] = {}
+ self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
+ self.inference_state["output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ return obj_idx
+ else:
+ raise RuntimeError(
+ f"Cannot add new object id {obj_id} after tracking starts. "
+ f"All existing object ids: {self.inference_state['obj_ids']}. "
+ f"Please call 'reset_state' to restart from scratch."
+ )
+
+ def _run_single_frame_inference(
+ self,
+ output_dict,
+ frame_idx,
+ batch_size,
+ is_init_cond_frame,
+ point_inputs,
+ mask_inputs,
+ reverse,
+ run_mem_encoder,
+ prev_sam_mask_logits=None,
+ ):
+ """
+ Run tracking on a single frame based on current inputs and previous memory.
+
+ Args:
+ output_dict (Dict): The dictionary containing the output states of the tracking process.
+ frame_idx (int): The index of the current frame.
+ batch_size (int): The batch size for processing the frame.
+ is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
+ point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
+ mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
+ reverse (bool): Indicates if the tracking should be performed in reverse order.
+ run_mem_encoder (bool): Indicates if the memory encoder should be executed.
+ prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
+
+ Returns:
+ current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
+
+ Raises:
+ AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
+
+ Note:
+ - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
+ - The method retrieves image features using the `get_im_features` method.
+ - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
+ - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
+ """
+ # Retrieve correct image features
+ current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
+ self.inference_state["im"], batch_size
+ )
+
+ # point and mask should not appear as input simultaneously on the same frame
+ assert point_inputs is None or mask_inputs is None
+ current_out = self.model.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=is_init_cond_frame,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=point_inputs,
+ mask_inputs=mask_inputs,
+ output_dict=output_dict,
+ num_frames=self.inference_state["num_frames"],
+ track_in_reverse=reverse,
+ run_mem_encoder=run_mem_encoder,
+ prev_sam_mask_logits=prev_sam_mask_logits,
+ )
+
+ maskmem_features = current_out["maskmem_features"]
+ if maskmem_features is not None:
+ current_out["maskmem_features"] = maskmem_features.to(
+ dtype=torch.float16, device=self.device, non_blocking=True
+ )
+ # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
+ # potentially fill holes in the predicted masks
+ # if self.fill_hole_area > 0:
+ # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
+ # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
+
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
+ return current_out
+
+ def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
+ """
+ Caches and manages the positional encoding for mask memory across frames and objects.
+
+ This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
+ mask memory, which is constant across frames and objects, thus reducing the amount of
+ redundant information stored during an inference session. It checks if the positional
+ encoding has already been cached; if not, it caches a slice of the provided encoding.
+ If the batch size is greater than one, it expands the cached positional encoding to match
+ the current batch size.
+
+ Args:
+ out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
+ Should be a list of tensors or None.
+
+ Returns:
+ out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
+
+ Note:
+ - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
+ - Only a single object's slice is cached since the encoding is the same across objects.
+ - The method checks if the positional encoding has already been cached in the session's constants.
+ - If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
+ """
+ model_constants = self.inference_state["constants"]
+ # "out_maskmem_pos_enc" should be either a list of tensors or None
+ if out_maskmem_pos_enc is not None:
+ if "maskmem_pos_enc" not in model_constants:
+ assert isinstance(out_maskmem_pos_enc, list)
+ # only take the slice for one object, since it's same across objects
+ maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+ model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+ else:
+ maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+ # expand the cached maskmem_pos_enc to the actual batch size
+ batch_size = out_maskmem_pos_enc[0].size(0)
+ if batch_size > 1:
+ out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
+ return out_maskmem_pos_enc
+
+ def _consolidate_temp_output_across_obj(
+ self,
+ frame_idx,
+ is_cond=False,
+ run_mem_encoder=False,
+ ):
+ """
+ Consolidates per-object temporary outputs into a single output for all objects.
+
+ This method combines the temporary outputs for each object on a given frame into a unified
+ output. It fills in any missing objects either from the main output dictionary or leaves
+ placeholders if they do not exist in the main output. Optionally, it can re-run the memory
+ encoder after applying non-overlapping constraints to the object scores.
+
+ Args:
+ frame_idx (int): The index of the frame for which to consolidate outputs.
+ is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
+ Defaults to False.
+ run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
+ consolidating the outputs. Defaults to False.
+
+ Returns:
+ consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
+
+ Note:
+ - The method initializes the consolidated output with placeholder values for missing objects.
+ - It searches for outputs in both the temporary and main output dictionaries.
+ - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
+ - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
+ """
+ batch_size = len(self.inference_state["obj_idx_to_id"])
+ storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+ # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+ # will be added when rerunning the memory encoder after applying non-overlapping
+ # constraints to object scores. Its "pred_masks" are prefilled with a large
+ # negative value (NO_OBJ_SCORE) to represent missing objects.
+ consolidated_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": torch.full(
+ size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
+ fill_value=-1024.0,
+ dtype=torch.float32,
+ device=self.device,
+ ),
+ "obj_ptr": torch.full(
+ size=(batch_size, self.model.hidden_dim),
+ fill_value=-1024.0,
+ dtype=torch.float32,
+ device=self.device,
+ ),
+ "object_score_logits": torch.full(
+ size=(batch_size, 1),
+ # default to 10.0 for object_score_logits, i.e. assuming the object is
+ # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
+ fill_value=10.0,
+ dtype=torch.float32,
+ device=self.device,
+ ),
+ }
+ for obj_idx in range(batch_size):
+ obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
+ obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
+ out = (
+ obj_temp_output_dict[storage_key].get(frame_idx)
+ # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+ # we fall back and look up its previous output in "output_dict_per_obj".
+ # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+ # "output_dict_per_obj" to find a previous output for this object.
+ or obj_output_dict["cond_frame_outputs"].get(frame_idx)
+ or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+ )
+ # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+ # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+ # placeholder above) and set its object pointer to be a dummy pointer.
+ if out is None:
+ # Fill in dummy object pointers for those objects without any inputs or
+ # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+ # i.e. when we need to build the memory for tracking).
+ if run_mem_encoder:
+ # fill object pointer with a dummy pointer (based on an empty mask)
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
+ continue
+ # Add the temporary object output mask to consolidated output mask
+ consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
+ consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+ # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
+ if run_mem_encoder:
+ high_res_masks = F.interpolate(
+ consolidated_out["pred_masks"],
+ size=self.imgsz,
+ mode="bilinear",
+ align_corners=False,
+ )
+ if self.model.non_overlap_masks_for_mem_enc:
+ high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
+ consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
+ batch_size=batch_size,
+ high_res_masks=high_res_masks,
+ is_mask_from_pts=True, # these frames are what the user interacted with
+ object_score_logits=consolidated_out["object_score_logits"],
+ )
+
+ return consolidated_out
+
+ def _get_empty_mask_ptr(self, frame_idx):
+ """
+ Get a dummy object pointer based on an empty mask on the current frame.
+
+ Args:
+ frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
+
+ Returns:
+ (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
+ """
+ # Retrieve correct image features
+ current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
+
+ # Feed the empty mask and image feature above to get a dummy object pointer
+ current_out = self.model.track_step(
+ frame_idx=frame_idx,
+ is_init_cond_frame=True,
+ current_vision_feats=current_vision_feats,
+ current_vision_pos_embeds=current_vision_pos_embeds,
+ feat_sizes=feat_sizes,
+ point_inputs=None,
+ # A dummy (empty) mask with a single object
+ mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
+ output_dict={},
+ num_frames=self.inference_state["num_frames"],
+ track_in_reverse=False,
+ run_mem_encoder=False,
+ prev_sam_mask_logits=None,
+ )
+ return current_out["obj_ptr"]
+
+ def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
+ """
+ Run the memory encoder on masks.
+
+ This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
+ memory also needs to be computed again with the memory encoder.
+
+ Args:
+ batch_size (int): The batch size for processing the frame.
+ high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
+ object_score_logits (torch.Tensor): Logits representing the object scores.
+ is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
+
+ Returns:
+ (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
+ """
+ # Retrieve correct image features
+ current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
+ maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
+ current_vision_feats=current_vision_feats,
+ feat_sizes=feat_sizes,
+ pred_masks_high_res=high_res_masks,
+ is_mask_from_pts=is_mask_from_pts,
+ object_score_logits=object_score_logits,
+ )
+
+ # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+ maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
+ return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
+
+ def _add_output_per_object(self, frame_idx, current_out, storage_key):
+ """
+ Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
+
+ The resulting slices share the same tensor storage.
+
+ Args:
+ frame_idx (int): The index of the current frame.
+ current_out (Dict): The current output dictionary containing multi-object outputs.
+ storage_key (str): The key used to store the output in the per-object output dictionary.
+ """
+ maskmem_features = current_out["maskmem_features"]
+ assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+ maskmem_pos_enc = current_out["maskmem_pos_enc"]
+ assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+ for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
+ obj_slice = slice(obj_idx, obj_idx + 1)
+ obj_out = {
+ "maskmem_features": None,
+ "maskmem_pos_enc": None,
+ "pred_masks": current_out["pred_masks"][obj_slice],
+ "obj_ptr": current_out["obj_ptr"][obj_slice],
+ }
+ if maskmem_features is not None:
+ obj_out["maskmem_features"] = maskmem_features[obj_slice]
+ if maskmem_pos_enc is not None:
+ obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+ obj_output_dict[storage_key][frame_idx] = obj_out
+
+ def _clear_non_cond_mem_around_input(self, frame_idx):
+ """
+ Remove the non-conditioning memory around the input frame.
+
+ When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
+ object appearance information and could confuse the model. This method clears those non-conditioning memories
+ surrounding the interacted frame to avoid giving the model both old and new information about the object.
+
+ Args:
+ frame_idx (int): The index of the current frame where user interaction occurred.
+ """
+ r = self.model.memory_temporal_stride_for_eval
+ frame_idx_begin = frame_idx - r * self.model.num_maskmem
+ frame_idx_end = frame_idx + r * self.model.num_maskmem
+ for t in range(frame_idx_begin, frame_idx_end + 1):
+ self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
+ for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
+ obj_output_dict["non_cond_frame_outputs"].pop(t, None)
diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py
index f3baaf4e..c78ee359 100644
--- a/ultralytics/trackers/basetrack.py
+++ b/ultralytics/trackers/basetrack.py
@@ -44,7 +44,7 @@ class BaseTrack:
start_frame (int): The frame number where tracking started.
frame_id (int): The most recent frame ID processed by the track.
time_since_update (int): Frames passed since the last update.
- location (Tuple): The location of the object in the context of multi-camera tracking.
+ location (tuple): The location of the object in the context of multi-camera tracking.
Methods:
end_frame: Returns the ID of the last frame where the object was tracked.
diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py
index f969f112..b062d938 100644
--- a/ultralytics/trackers/utils/matching.py
+++ b/ultralytics/trackers/utils/matching.py
@@ -27,10 +27,9 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
Returns:
- (tuple): A tuple containing:
- - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
- - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
- - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
+ matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
+ unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
+ unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
Examples:
>>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 94038aef..73952868 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -552,9 +552,8 @@ class v8PoseLoss(v8DetectionLoss):
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
Returns:
- (tuple): Returns a tuple containing:
- - kpts_loss (torch.Tensor): The keypoints loss.
- - kpts_obj_loss (torch.Tensor): The keypoints object loss.
+ kpts_loss (torch.Tensor): The keypoints loss.
+ kpts_obj_loss (torch.Tensor): The keypoints object loss.
"""
batch_idx = batch_idx.flatten()
batch_size = len(masks)
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 2b80c02f..bb521f5c 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -549,19 +549,18 @@ def ap_per_class(
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
Returns:
- (tuple): A tuple of six arrays and one array of unique classes, where:
- tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
- fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
- p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
- r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
- f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
- ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
- unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
- p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
- r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
- f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
- x (np.ndarray): X-axis values for the curves. Shape: (1000,).
- prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
+ tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
+ fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+ p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
+ r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
+ f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
+ ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
+ unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
+ p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
+ r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
+ f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
+ x (np.ndarray): X-axis values for the curves. Shape: (1000,).
+ prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
"""
# Sort by objectness
i = np.argsort(-conf)
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
index 07b54b3e..25e83c61 100644
--- a/ultralytics/utils/ops.py
+++ b/ultralytics/utils/ops.py
@@ -317,11 +317,11 @@ def clip_boxes(boxes, shape):
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
Args:
- boxes (torch.Tensor): the bounding boxes to clip
- shape (tuple): the shape of the image
+ boxes (torch.Tensor): The bounding boxes to clip.
+ shape (tuple): The shape of the image.
Returns:
- (torch.Tensor | numpy.ndarray): Clipped boxes
+ (torch.Tensor | numpy.ndarray): The clipped boxes.
"""
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
@@ -359,9 +359,9 @@ def scale_image(masks, im0_shape, ratio_pad=None):
Takes a mask, and resizes it to the original image size.
Args:
- masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
- im0_shape (tuple): the original image shape
- ratio_pad (tuple): the ratio of the padding to the original image.
+ masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
+ im0_shape (tuple): The original image shape.
+ ratio_pad (tuple): The ratio of the padding to the original image.
Returns:
masks (np.ndarray): The masks that are being returned with shape [h, w, num].
@@ -692,12 +692,12 @@ def process_mask_native(protos, masks_in, bboxes, shape):
Args:
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
- masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
- bboxes (torch.Tensor): [n, 4], n is number of masks after nms
- shape (tuple): the size of the input image (h,w)
+ masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
+ bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
+ shape (tuple): The size of the input image (h,w).
Returns:
- masks (torch.Tensor): The returned masks with dimensions [h, w, n]
+ masks (torch.Tensor): The returned masks with dimensions [h, w, n].
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
index 6e257634..f4514247 100644
--- a/ultralytics/utils/plotting.py
+++ b/ultralytics/utils/plotting.py
@@ -584,8 +584,8 @@ class Annotator:
Displays queue counts on an image centered at the points with customizable font size and colors.
Args:
- label (str): queue counts label
- points (tuple): region points for center point calculation to display text
+ label (str): Queue counts label.
+ points (tuple): Region points for center point calculation to display text.
region_color (tuple): RGB queue region color.
txt_color (tuple): RGB text display color.
"""
@@ -624,13 +624,13 @@ class Annotator:
Display the bounding boxes labels in parking management app.
Args:
- im0 (ndarray): inference image
- text (str): object/class name
- txt_color (tuple): display color for text foreground
- bg_color (tuple): display color for text background
- x_center (float): x position center point for bounding box
- y_center (float): y position center point for bounding box
- margin (int): gap between text and rectangle for better display
+ im0 (ndarray): Inference image.
+ text (str): Object/class name.
+ txt_color (tuple): Display color for text foreground.
+ bg_color (tuple): Display color for text background.
+ x_center (float): The x position center point for bounding box.
+ y_center (float): The y position center point for bounding box.
+ margin (int): The gap between text and rectangle for better display.
"""
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
text_x = x_center - text_size[0] // 2
@@ -648,11 +648,11 @@ class Annotator:
Display the overall statistics for parking lots.
Args:
- im0 (ndarray): inference image
- text (dict): labels dictionary
- txt_color (tuple): display color for text foreground
- bg_color (tuple): display color for text background
- margin (int): gap between text and rectangle for better display
+ im0 (ndarray): Inference image.
+ text (dict): Labels dictionary.
+ txt_color (tuple): Display color for text foreground.
+ bg_color (tuple): Display color for text background.
+ margin (int): Gap between text and rectangle for better display.
"""
horizontal_gap = int(im0.shape[1] * 0.02)
vertical_gap = int(im0.shape[0] * 0.01)