diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md index 86059422..983d8cdc 100644 --- a/docs/en/models/sam-2.md +++ b/docs/en/models/sam-2.md @@ -194,6 +194,34 @@ SAM 2 can be utilized across a broad spectrum of tasks, including real-time vide yolo predict model=sam2.1_b.pt source=path/to/video.mp4 ``` +#### Segment Video and Track objects + +!!! example "Segment Video" + + Segment the entire video content with specific prompts and track objects. + + === "Python" + + ```python + from ultralytics.models.sam import SAM2VideoPredictor + + # Create SAM2VideoPredictor + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt") + predictor = SAM2VideoPredictor(overrides=overrides) + + # Run inference with single point + results = predictor(source="test.mp4", points=[920, 470], labels=1) + + # Run inference with multiple points + results = predictor(source="test.mp4", points=[[920, 470], [909, 138]], labels=[1, 1]) + + # Run inference with multiple points prompt per object + results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 1]]) + + # Run inference with negative points prompt + results = predictor(source="test.mp4", points=[[[920, 470], [909, 138]]], labels=[[1, 0]]) + ``` + - This example demonstrates how SAM 2 can be used to segment the entire content of an image or video if no prompts (bboxes/points/masks) are provided. ## SAM 2 comparison vs YOLOv8 diff --git a/docs/en/reference/models/sam/predict.md b/docs/en/reference/models/sam/predict.md index e715225c..17f8b472 100644 --- a/docs/en/reference/models/sam/predict.md +++ b/docs/en/reference/models/sam/predict.md @@ -17,4 +17,8 @@ keywords: Ultralytics, SAM, Segment Anything Model, SAM 2, Segment Anything Mode ## ::: ultralytics.models.sam.predict.SAM2Predictor +
+<br><br><hr><br>
+
+## ::: ultralytics.models.sam.predict.SAM2VideoPredictor
+
 <br><br>
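For orientation on the new `SAM2VideoPredictor` reference entry above, here is a minimal usage sketch. It mirrors the example this diff adds to `docs/en/models/sam-2.md`; the video path and point coordinates are placeholders, and the results loop assumes the predictor returns standard Ultralytics `Results` objects per frame (as implied by its `postprocess` override further down in this patch).

```python
from ultralytics.models.sam import SAM2VideoPredictor

# Configure the predictor for promptable video segmentation (model name and source are placeholders)
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam2_b.pt")
predictor = SAM2VideoPredictor(overrides=overrides)

# Track a single object seeded by one foreground point on the first frame
results = predictor(source="path/to/video.mp4", points=[920, 470], labels=1)

# Each Results object corresponds to one video frame; masks follow the usual Results API
for r in results:
    if r.masks is not None:
        print(r.masks.data.shape)  # (num_objects, H, W) mask tensor for this frame
```
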
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index e17c6c7d..fe22ab07 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.37" +__version__ = "8.3.38" import os diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 9bfd6a60..de9ef96a 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -740,9 +740,8 @@ def parse_key_value_pair(pair: str = "key=value"): pair (str): A string containing a key-value pair in the format "key=value". Returns: - (tuple): A tuple containing two elements: - - key (str): The parsed key. - - value (str): The parsed value. + key (str): The parsed key. + value (str): The parsed value. Raises: AssertionError: If the value is missing or empty. diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index bd821de2..5ec011d8 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -2111,10 +2111,9 @@ class Format: h (int): Height of the image. Returns: - (tuple): Tuple containing: - masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True. - instances (Instances): Updated instances object with sorted segments if mask_overlap is True. - cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True. + masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True. + instances (Instances): Updated instances object with sorted segments if mask_overlap is True. + cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True. Notes: - If self.mask_overlap is True, masks are overlapped and sorted by area. diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py index ead7d613..ae5677cc 100644 --- a/ultralytics/data/loaders.py +++ b/ultralytics/data/loaders.py @@ -354,7 +354,7 @@ class LoadImagesAndVideos: self.nf = ni + nv # number of files self.ni = ni # number of images self.video_flag = [False] * ni + [True] * nv - self.mode = "image" + self.mode = "video" if ni == 0 else "image" # default to video if no images self.vid_stride = vid_stride # video frame-rate stride self.bs = batch if any(videos): diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py index a29f5cb3..30e34236 100644 --- a/ultralytics/models/sam/__init__.py +++ b/ultralytics/models/sam/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license from .model import SAM -from .predict import Predictor, SAM2Predictor +from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor -__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list +__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index e685dc4e..97349a66 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -148,7 +148,7 @@ class SAM(Model): verbose (bool): If True, prints the information to the console. Returns: - (Tuple): A tuple containing the model's information (string representations of the model). + (tuple): A tuple containing the model's information (string representations of the model). 
Examples: >>> sam = SAM("sam_b.pt") diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 7bfd7166..5d48ed1f 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -36,8 +36,6 @@ class SAMModel(nn.Module): image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings. prompt_encoder (PromptEncoder): Encoder for various types of input prompts. mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings. - pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1). - pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1). Methods: __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters. @@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module): self.sam_prompt_embed_dim = self.hidden_dim self.sam_image_embedding_size = self.image_size // self.backbone_stride - # build PromptEncoder and MaskDecoder from SAM - # (their hyperparameters like `mask_in_chans=16` are from SAM code) + # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code) self.sam_prompt_encoder = PromptEncoder( embed_dim=self.sam_prompt_embed_dim, image_embedding_size=( @@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module): low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. ious: Tensor of shape (B, M) with estimated IoU for each output mask. - low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. - high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. + low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask. + high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask. obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. object_score_logits: Tensor of shape (B,) with object score logits. 
@@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module): boxes=None, masks=sam_mask_prompt, ) - ( - low_res_multimasks, - ious, - sam_output_tokens, - object_score_logits, - ) = self.sam_mask_decoder( + low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder( image_embeddings=backbone_features, image_pe=self.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, @@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module): if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 - # Mask used for spatial memories is always a *hard* choice between obj and no obj, - # consistent with the actual mask prediction - low_res_multimasks = torch.where( - is_obj_appearing[:, None, None], - low_res_multimasks, - NO_OBJ_SCORE, - ) + # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE) # convert masks from possibly bfloat16 (or float16) to float32 # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) @@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module): def _prepare_backbone_features(self, backbone_out): """Prepares and flattens visual features from the image backbone output for further processing.""" - backbone_out = backbone_out.copy() assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels @@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module): mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc - maskmem_out = self.memory_encoder( - pix_feat, - mask_for_mem, - skip_mask_sigmoid=True, # sigmoid already applied - ) + maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied maskmem_features = maskmem_out["vision_features"] maskmem_pos_enc = maskmem_out["vision_pos_enc"] # add a no-object embedding to the spatial memory to indicate that the frame @@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module): track_in_reverse, prev_sam_mask_logits, ) - - ( - _, - _, - _, - low_res_masks, - high_res_masks, - obj_ptr, - object_score_logits, - ) = sam_outputs + _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks @@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module): # it's mainly used in the demo to encode spatial memories w/ consolidated masks) current_out["object_score_logits"] = object_score_logits - # Finally run the memory encoder on the predicted mask to encode - # it into a new memory feature (that can be used in future frames) + # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames) self._encode_memory_in_output( current_vision_feats, feat_sizes, @@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module): and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) ) - def _apply_non_overlapping_constraints(self, pred_masks): - """Applies non-overlapping constraints to masks, keeping highest scoring object per location.""" + @staticmethod + def _apply_non_overlapping_constraints(pred_masks): + """Applies non-overlapping constraints to masks, keeping the highest scoring object per location.""" batch_size = pred_masks.size(0) if batch_size == 1: return 
pred_masks @@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module): pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) return pred_masks + def set_binarize(self, binarize=False): + """Set binarize for VideoPredictor.""" + self.binarize_mask_from_pts_for_mem_enc = binarize + def set_imgsz(self, imgsz): """ Set image size to make model compatible with different image sizes. diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index a8315908..540d1007 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -8,6 +8,8 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe segmentation tasks. """ +from collections import OrderedDict + import numpy as np import torch import torch.nn.functional as F @@ -16,7 +18,7 @@ from ultralytics.data.augment import LetterBox from ultralytics.engine.predictor import BasePredictor from ultralytics.engine.results import Results from ultralytics.utils import DEFAULT_CFG, ops -from ultralytics.utils.torch_utils import select_device +from ultralytics.utils.torch_utils import select_device, smart_inference_mode from .amg import ( batch_iterator, @@ -95,7 +97,7 @@ class Predictor(BasePredictor): """ if overrides is None: overrides = {} - overrides.update(dict(task="segment", mode="predict")) + overrides.update(dict(task="segment", mode="predict", batch=1)) super().__init__(cfg, overrides, _callbacks) self.args.retina_masks = True self.im = None @@ -114,7 +116,7 @@ class Predictor(BasePredictor): im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. Returns: - (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. Examples: >>> predictor = Predictor() @@ -181,10 +183,9 @@ class Predictor(BasePredictor): **kwargs (Any): Additional keyword arguments. Returns: - (tuple): Contains the following three elements: - - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks. - - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. - - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. + (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. + (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. Examples: >>> predictor = Predictor() @@ -222,10 +223,8 @@ class Predictor(BasePredictor): AssertionError: If the number of points don't match the number of labels, in case labels were passed. Returns: - (tuple): Tuple containing: - - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. - - np.ndarray: Quality scores predicted by the model for each mask, with length C. - - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256. + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores predicted by the model for each mask, with length C. Examples: >>> predictor = Predictor() @@ -329,10 +328,9 @@ class Predictor(BasePredictor): crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. 
Returns: - (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: - - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). - - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). - - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). + pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). Examples: >>> predictor = Predictor() @@ -408,7 +406,7 @@ class Predictor(BasePredictor): return pred_masks, pred_scores, pred_bboxes - def setup_model(self, model, verbose=True): + def setup_model(self, model=None, verbose=True): """ Initializes the Segment Anything Model (SAM) for inference. @@ -416,7 +414,7 @@ class Predictor(BasePredictor): parameters for image normalization and other Ultralytics compatibility settings. Args: - model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration. + model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config. verbose (bool): If True, prints selected device information. Examples: @@ -459,7 +457,7 @@ class Predictor(BasePredictor): orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. Returns: - (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other metadata for each processed image. Examples: @@ -586,9 +584,8 @@ class Predictor(BasePredictor): nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. Returns: - (tuple): - - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). - - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. Examples: >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks @@ -690,10 +687,8 @@ class SAM2Predictor(Predictor): img_idx (int): Index of the image in the batch to process. Returns: - (tuple): Tuple containing: - - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. - - np.ndarray: Quality scores for each mask, with length C. - - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference. + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores for each mask, with length C. Examples: >>> predictor = SAM2Predictor(cfg) @@ -712,7 +707,7 @@ class SAM2Predictor(Predictor): """ features = self.get_im_features(im) if self.features is None else self.features - bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) points = (points, labels) if points is not None else None sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( @@ -751,7 +746,7 @@ class SAM2Predictor(Predictor): AssertionError: If the number of points don't match the number of labels, in case labels were passed. Returns: - (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. 
+ (tuple): A tuple containing transformed points, labels, and masks. """ bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) if bboxes is not None: @@ -764,7 +759,7 @@ class SAM2Predictor(Predictor): labels = torch.cat([bbox_labels, labels], dim=1) else: points, labels = bboxes, bbox_labels - return bboxes, points, labels, masks + return points, labels, masks def set_image(self, image): """ @@ -815,3 +810,797 @@ class SAM2Predictor(Predictor): for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) ][::-1] return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + + +class SAM2VideoPredictor(SAM2Predictor): + """ + SAM2VideoPredictor to handle user interactions with videos and manage inference states. + + This class extends the functionality of SAM2Predictor to support video processing and maintains + the state of inference operations. It includes configurations for managing non-overlapping masks, + clearing memory for non-conditional inputs, and setting up callbacks for prediction events. + + Attributes: + inference_state (Dict): A dictionary to store the current state of inference operations. + non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping. + clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs. + clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios. + callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events. + + Args: + cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG. + overrides (Dict, Optional): Additional configuration overrides. Defaults to None. + _callbacks (List, Optional): Custom callbacks to be added. Defaults to None. + + Note: + The `fill_hole_area` attribute is defined but not used in the current implementation. + """ + + # fill_hole_area = 8 # not used + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the predictor with configuration and optional overrides. + + This constructor initializes the SAM2VideoPredictor with a given configuration, applies any + specified overrides, and sets up the inference state along with certain flags + that control the behavior of the predictor. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. + + Examples: + >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG) + >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640}) + >>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback}) + """ + super().__init__(cfg, overrides, _callbacks) + self.inference_state = {} + self.non_overlap_masks = True + self.clear_non_cond_mem_around_input = False + self.clear_non_cond_mem_for_multi_obj = False + self.callbacks["on_predict_start"].append(self.init_state) + + def get_model(self): + """ + Retrieves and configures the model with binarization enabled. + + Note: + This method overrides the base class implementation to set the binarize flag to True. + """ + model = super().get_model() + model.set_binarize(True) + return model + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. 
This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. + + Returns: + (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object(frame, current_out, storage_key) + self.inference_state["frames_already_tracked"].append(frame) + pred_masks = current_out["pred_masks"].flatten(0, 1) + pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks + + return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes the predictions to apply non-overlapping constraints if required. 
+ + This method extends the post-processing functionality by applying non-overlapping constraints + to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that + the masks do not overlap, which can be useful for certain applications. + + Args: + preds (Tuple[torch.Tensor]): The predictions from the model. + img (torch.Tensor): The processed image tensor. + orig_imgs (List[np.ndarray]): The original images before processing. + + Returns: + results (list): The post-processed predictions. + + Note: + If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks. + """ + results = super().postprocess(preds, img, orig_imgs) + if self.non_overlap_masks: + for result in results: + if result.masks is None or len(result.masks) == 0: + continue + result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0] + return results + + @smart_inference_mode() + def add_new_prompts( + self, + obj_id, + points=None, + labels=None, + masks=None, + frame_idx=0, + ): + """ + Adds new points or masks to a specific frame for a given object ID. + + This method updates the inference state with new prompts (points or masks) for a specified + object and frame index. It ensures that the prompts are either points or masks, but not both, + and updates the internal state accordingly. It also handles the generation of new segmentations + based on the provided prompts and the existing state. + + Args: + obj_id (int): The ID of the object to which the prompts are associated. + points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None. + labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None. + masks (torch.Tensor, optional): Binary masks for the object. Defaults to None. + frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0. + + Returns: + (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects. + + Raises: + AssertionError: If both `masks` and `points` are provided, or neither is provided. + + Note: + - Only one type of prompt (either points or masks) can be added per call. + - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. + - The method handles the consolidation of outputs and resizing of masks to the original video resolution. + """ + assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." + obj_idx = self._obj_id_to_idx(obj_id) + + point_inputs = None + pop_key = "point_inputs_per_obj" + if points is not None: + point_inputs = {"point_coords": points, "point_labels": labels} + self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + pop_key = "mask_inputs_per_obj" + self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. 
+ is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + if point_inputs is not None: + prev_out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + + if prev_out is not None and prev_out.get("pred_masks") is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits.clamp_(-32.0, 32.0) + current_out = self._run_single_frame_inference( + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=masks, + reverse=False, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + ) + pred_masks = consolidated_out["pred_masks"].flatten(0, 1) + return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device) + + @smart_inference_mode() + def propagate_in_video_preflight(self): + """ + Prepare inference_state and consolidate temporary outputs before tracking. + + This method marks the start of tracking, disallowing the addition of new objects until the session is reset. + It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. + Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent + with the provided inputs. + """ + # Tracking has started and we don't allow adding new objects until session is reset. + self.inference_state["tracking_has_started"] = True + batch_size = len(self.inference_state["obj_idx_to_id"]) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". 
+ temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"] + output_dict = self.inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + for is_cond in {False, True}: + # Separately consolidate conditioning and non-conditioning temp outptus + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temprary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(frame_idx, consolidated_out, storage_key) + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @staticmethod + def init_state(predictor): + """ + Initialize an inference state for the predictor. + + This function sets up the initial state required for performing inference on video data. + It includes initializing various dictionaries and ordered dictionaries that will store + inputs, outputs, and other metadata relevant to the tracking process. 
+ + Args: + predictor (SAM2VideoPredictor): The predictor object for which to initialize the state. + """ + if len(predictor.inference_state) > 0: # means initialized + return + assert predictor.dataset is not None + assert predictor.dataset.mode == "video" + + inference_state = {} + inference_state["num_frames"] = predictor.dataset.frames + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = [] + predictor.inference_state = inference_state + + def get_im_features(self, im, batch=1): + """ + Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks. + + Args: + im (torch.Tensor): The input image tensor. + batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1. + + Returns: + vis_feats (torch.Tensor): The visual features extracted from the image. + vis_pos_embed (torch.Tensor): The positional embeddings for the visual features. + feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features. + + Note: + - If `batch` is greater than 1, the features are expanded to fit the batch size. + - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. + """ + backbone_out = self.model.forward_image(im) + if batch > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + + def _obj_id_to_idx(self, obj_id): + """ + Map client-side object id to model-side object index. + + Args: + obj_id (int): The unique identifier of the object provided by the client side. + + Returns: + obj_idx (int): The index of the object on the model side. 
+ + Raises: + RuntimeError: If an attempt is made to add a new object after tracking has started. + + Note: + - The method updates or retrieves mappings between object IDs and indices stored in + `inference_state`. + - It ensures that new objects can only be added before tracking commences. + - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). + - Additional data structures are initialized for the new object to store inputs and outputs. + """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. + is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict, Optional): Input points and their labels. Defaults to None. + mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. 
+ - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. + """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( + self.inference_state["im"], batch_size + ) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=self.inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + current_out["maskmem_features"] = maskmem_features.to( + dtype=torch.float16, device=self.device, non_blocking=True + ) + # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions + # potentially fill holes in the predicted masks + # if self.fill_hole_area > 0: + # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True) + # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + return current_out + + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + """ + Caches and manages the positional encoding for mask memory across frames and objects. + + This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for + mask memory, which is constant across frames and objects, thus reducing the amount of + redundant information stored during an inference session. It checks if the positional + encoding has already been cached; if not, it caches a slice of the provided encoding. + If the batch size is greater than one, it expands the cached positional encoding to match + the current batch size. + + Args: + out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory. + Should be a list of tensors or None. + + Returns: + out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded. + + Note: + - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None. + - Only a single object's slice is cached since the encoding is the same across objects. + - The method checks if the positional encoding has already been cached in the session's constants. + - If the batch size is greater than one, the cached encoding is expanded to fit the batch size. 
+ """ + model_constants = self.inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + if batch_size > 1: + out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc] + return out_maskmem_pos_enc + + def _consolidate_temp_output_across_obj( + self, + frame_idx, + is_cond=False, + run_mem_encoder=False, + ): + """ + Consolidates per-object temporary outputs into a single output for all objects. + + This method combines the temporary outputs for each object on a given frame into a unified + output. It fills in any missing objects either from the main output dictionary or leaves + placeholders if they do not exist in the main output. Optionally, it can re-run the memory + encoder after applying non-overlapping constraints to the object scores. + + Args: + frame_idx (int): The index of the frame for which to consolidate outputs. + is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame. + Defaults to False. + run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after + consolidating the outputs. Defaults to False. + + Returns: + consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects. + + Note: + - The method initializes the consolidated output with placeholder values for missing objects. + - It searches for outputs in both the temporary and main output dictionaries. + - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. + - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. + """ + batch_size = len(self.inference_state["obj_idx_to_id"]) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": torch.full( + size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "obj_ptr": torch.full( + size=(batch_size, self.model.hidden_dim), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. 
assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=self.device, + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx) + continue + # Add the temporary object output mask to consolidated output mask + consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"] + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder + if run_mem_encoder: + high_res_masks = F.interpolate( + consolidated_out["pred_masks"], + size=self.imgsz, + mode="bilinear", + align_corners=False, + ) + if self.model.non_overlap_masks_for_mem_enc: + high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks) + consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder( + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + object_score_logits=consolidated_out["object_score_logits"], + ) + + return consolidated_out + + def _get_empty_mask_ptr(self, frame_idx): + """ + Get a dummy object pointer based on an empty mask on the current frame. + + Args: + frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + + Returns: + (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"]) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + # A dummy (empty) mask with a single object + mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device), + output_dict={}, + num_frames=self.inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts): + """ + Run the memory encoder on masks. + + This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their + memory also needs to be computed again with the memory encoder. + + Args: + batch_size (int): The batch size for processing the frame. + high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory. + object_score_logits (torch.Tensor): Logits representing the object scores. + is_mask_from_pts (bool): Indicates if the mask is derived from point interactions. + + Returns: + (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding. + """ + # Retrieve correct image features + current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size) + maskmem_features, maskmem_pos_enc = self.model._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + object_score_logits=object_score_logits, + ) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc) + return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc + + def _add_output_per_object(self, frame_idx, current_out, storage_key): + """ + Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj. + + The resulting slices share the same tensor storage. + + Args: + frame_idx (int): The index of the current frame. + current_out (Dict): The current output dictionary containing multi-object outputs. + storage_key (str): The key used to store the output in the per-object output dictionary. 
+ """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + def _clear_non_cond_mem_around_input(self, frame_idx): + """ + Remove the non-conditioning memory around the input frame. + + When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated + object appearance information and could confuse the model. This method clears those non-conditioning memories + surrounding the interacted frame to avoid giving the model both old and new information about the object. + + Args: + frame_idx (int): The index of the current frame where user interaction occurred. + """ + r = self.model.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.model.num_maskmem + frame_idx_end = frame_idx + r * self.model.num_maskmem + for t in range(frame_idx_begin, frame_idx_end + 1): + self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py index f3baaf4e..c78ee359 100644 --- a/ultralytics/trackers/basetrack.py +++ b/ultralytics/trackers/basetrack.py @@ -44,7 +44,7 @@ class BaseTrack: start_frame (int): The frame number where tracking started. frame_id (int): The most recent frame ID processed by the track. time_since_update (int): Frames passed since the last update. - location (Tuple): The location of the object in the context of multi-camera tracking. + location (tuple): The location of the object in the context of multi-camera tracking. Methods: end_frame: Returns the ID of the last frame where the object was tracked. diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py index f969f112..b062d938 100644 --- a/ultralytics/trackers/utils/matching.py +++ b/ultralytics/trackers/utils/matching.py @@ -27,10 +27,9 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used. Returns: - (tuple): A tuple containing: - - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches. - - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,). - - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,). + matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches. + unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,). + unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,). 
 
     Examples:
         >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
index 94038aef..73952868 100644
--- a/ultralytics/utils/loss.py
+++ b/ultralytics/utils/loss.py
@@ -552,9 +552,8 @@ class v8PoseLoss(v8DetectionLoss):
             pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
 
         Returns:
-            (tuple): Returns a tuple containing:
-                - kpts_loss (torch.Tensor): The keypoints loss.
-                - kpts_obj_loss (torch.Tensor): The keypoints object loss.
+            kpts_loss (torch.Tensor): The keypoints loss.
+            kpts_obj_loss (torch.Tensor): The keypoints object loss.
         """
         batch_idx = batch_idx.flatten()
         batch_size = len(masks)
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
index 2b80c02f..bb521f5c 100644
--- a/ultralytics/utils/metrics.py
+++ b/ultralytics/utils/metrics.py
@@ -549,19 +549,18 @@ def ap_per_class(
         prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
 
     Returns:
-        (tuple): A tuple of six arrays and one array of unique classes, where:
-            tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
-            fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
-            p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
-            r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
-            f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
-            ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
-            unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
-            p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
-            r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
-            f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
-            x (np.ndarray): X-axis values for the curves. Shape: (1000,).
-            prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
+        tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
+        p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
+        r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
+        f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
+        ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
+        unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
+        p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
+        r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
+        f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
+        x (np.ndarray): X-axis values for the curves. Shape: (1000,).
+        prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
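+
+    Examples:
+        Minimal illustrative call; the array shapes below are assumptions (100 detections, 80 labels, 3 classes,
+        10 IoU thresholds), not values from a real validation run.
+        >>> import numpy as np
+        >>> from ultralytics.utils.metrics import ap_per_class
+        >>> tp = np.random.rand(100, 10) > 0.5  # per-detection hits at each of 10 IoU thresholds
+        >>> conf = np.random.rand(100)  # detection confidence scores
+        >>> pred_cls = np.random.randint(0, 3, 100)  # predicted class indices
+        >>> target_cls = np.random.randint(0, 3, 80)  # ground-truth class indices
+        >>> tp, fp, p, r, f1, ap, unique_classes, p_curve, r_curve, f1_curve, x, prec_values = ap_per_class(
+        ...     tp, conf, pred_cls, target_cls
+        ... )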
""" # Sort by objectness i = np.argsort(-conf) diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py index 07b54b3e..25e83c61 100644 --- a/ultralytics/utils/ops.py +++ b/ultralytics/utils/ops.py @@ -317,11 +317,11 @@ def clip_boxes(boxes, shape): Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. Args: - boxes (torch.Tensor): the bounding boxes to clip - shape (tuple): the shape of the image + boxes (torch.Tensor): The bounding boxes to clip. + shape (tuple): The shape of the image. Returns: - (torch.Tensor | numpy.ndarray): Clipped boxes + (torch.Tensor | numpy.ndarray): The clipped boxes. """ if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 @@ -359,9 +359,9 @@ def scale_image(masks, im0_shape, ratio_pad=None): Takes a mask, and resizes it to the original image size. Args: - masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3]. - im0_shape (tuple): the original image shape - ratio_pad (tuple): the ratio of the padding to the original image. + masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): The original image shape. + ratio_pad (tuple): The ratio of the padding to the original image. Returns: masks (np.ndarray): The masks that are being returned with shape [h, w, num]. @@ -692,12 +692,12 @@ def process_mask_native(protos, masks_in, bboxes, shape): Args: protos (torch.Tensor): [mask_dim, mask_h, mask_w] - masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms - bboxes (torch.Tensor): [n, 4], n is number of masks after nms - shape (tuple): the size of the input image (h,w) + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms. + bboxes (torch.Tensor): [n, 4], n is number of masks after nms. + shape (tuple): The size of the input image (h,w). Returns: - masks (torch.Tensor): The returned masks with dimensions [h, w, n] + masks (torch.Tensor): The returned masks with dimensions [h, w, n]. """ c, mh, mw = protos.shape # CHW masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py index 6e257634..f4514247 100644 --- a/ultralytics/utils/plotting.py +++ b/ultralytics/utils/plotting.py @@ -584,8 +584,8 @@ class Annotator: Displays queue counts on an image centered at the points with customizable font size and colors. Args: - label (str): queue counts label - points (tuple): region points for center point calculation to display text + label (str): Queue counts label. + points (tuple): Region points for center point calculation to display text. region_color (tuple): RGB queue region color. txt_color (tuple): RGB text display color. """ @@ -624,13 +624,13 @@ class Annotator: Display the bounding boxes labels in parking management app. Args: - im0 (ndarray): inference image - text (str): object/class name - txt_color (tuple): display color for text foreground - bg_color (tuple): display color for text background - x_center (float): x position center point for bounding box - y_center (float): y position center point for bounding box - margin (int): gap between text and rectangle for better display + im0 (ndarray): Inference image. + text (str): Object/class name. + txt_color (tuple): Display color for text foreground. + bg_color (tuple): Display color for text background. + x_center (float): The x position center point for bounding box. 
+            y_center (float): The y position center point for bounding box.
+            margin (int): The gap between text and rectangle for better display.
         """
         text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
         text_x = x_center - text_size[0] // 2
@@ -648,11 +648,11 @@ class Annotator:
         Display the overall statistics for parking lots.
 
         Args:
-            im0 (ndarray): inference image
-            text (dict): labels dictionary
-            txt_color (tuple): display color for text foreground
-            bg_color (tuple): display color for text background
-            margin (int): gap between text and rectangle for better display
+            im0 (ndarray): Inference image.
+            text (dict): Labels dictionary.
+            txt_color (tuple): Display color for text foreground.
+            bg_color (tuple): Display color for text background.
+            margin (int): Gap between text and rectangle for better display.
         """
         horizontal_gap = int(im0.shape[1] * 0.02)
         vertical_gap = int(im0.shape[0] * 0.01)