ultralytics 8.3.38 SAM 2 video inference (#14851)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
407815cf9e
commit
dcc9bd536f
16 changed files with 917 additions and 124 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.3.37"
|
||||
__version__ = "8.3.38"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -740,9 +740,8 @@ def parse_key_value_pair(pair: str = "key=value"):
|
|||
pair (str): A string containing a key-value pair in the format "key=value".
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing two elements:
|
||||
- key (str): The parsed key.
|
||||
- value (str): The parsed value.
|
||||
key (str): The parsed key.
|
||||
value (str): The parsed value.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the value is missing or empty.
|
||||
|
|
|
|||
|
|
@ -2111,10 +2111,9 @@ class Format:
|
|||
h (int): Height of the image.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
|
||||
instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
|
||||
cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
|
||||
masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True.
|
||||
instances (Instances): Updated instances object with sorted segments if mask_overlap is True.
|
||||
cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True.
|
||||
|
||||
Notes:
|
||||
- If self.mask_overlap is True, masks are overlapped and sorted by area.
|
||||
|
|
|
|||
|
|
@ -354,7 +354,7 @@ class LoadImagesAndVideos:
|
|||
self.nf = ni + nv # number of files
|
||||
self.ni = ni # number of images
|
||||
self.video_flag = [False] * ni + [True] * nv
|
||||
self.mode = "image"
|
||||
self.mode = "video" if ni == 0 else "image" # default to video if no images
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
self.bs = batch
|
||||
if any(videos):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .model import SAM
|
||||
from .predict import Predictor, SAM2Predictor
|
||||
from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor
|
||||
|
||||
__all__ = "SAM", "Predictor", "SAM2Predictor" # tuple or list
|
||||
__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list
|
||||
|
|
|
|||
|
|
@ -148,7 +148,7 @@ class SAM(Model):
|
|||
verbose (bool): If True, prints the information to the console.
|
||||
|
||||
Returns:
|
||||
(Tuple): A tuple containing the model's information (string representations of the model).
|
||||
(tuple): A tuple containing the model's information (string representations of the model).
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
|
|
|
|||
|
|
@ -36,8 +36,6 @@ class SAMModel(nn.Module):
|
|||
image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings.
|
||||
prompt_encoder (PromptEncoder): Encoder for various types of input prompts.
|
||||
mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings.
|
||||
pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1).
|
||||
pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1).
|
||||
|
||||
Methods:
|
||||
__init__: Initializes the SAMModel with encoders, decoder, and normalization parameters.
|
||||
|
|
@ -349,8 +347,7 @@ class SAM2Model(torch.nn.Module):
|
|||
self.sam_prompt_embed_dim = self.hidden_dim
|
||||
self.sam_image_embedding_size = self.image_size // self.backbone_stride
|
||||
|
||||
# build PromptEncoder and MaskDecoder from SAM
|
||||
# (their hyperparameters like `mask_in_chans=16` are from SAM code)
|
||||
# Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code)
|
||||
self.sam_prompt_encoder = PromptEncoder(
|
||||
embed_dim=self.sam_prompt_embed_dim,
|
||||
image_embedding_size=(
|
||||
|
|
@ -425,8 +422,8 @@ class SAM2Model(torch.nn.Module):
|
|||
low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits.
|
||||
high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits.
|
||||
ious: Tensor of shape (B, M) with estimated IoU for each output mask.
|
||||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask.
|
||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask.
|
||||
low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask.
|
||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
|
||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
||||
object_score_logits: Tensor of shape (B,) with object score logits.
|
||||
|
||||
|
|
@ -488,12 +485,7 @@ class SAM2Model(torch.nn.Module):
|
|||
boxes=None,
|
||||
masks=sam_mask_prompt,
|
||||
)
|
||||
(
|
||||
low_res_multimasks,
|
||||
ious,
|
||||
sam_output_tokens,
|
||||
object_score_logits,
|
||||
) = self.sam_mask_decoder(
|
||||
low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder(
|
||||
image_embeddings=backbone_features,
|
||||
image_pe=self.sam_prompt_encoder.get_dense_pe(),
|
||||
sparse_prompt_embeddings=sparse_embeddings,
|
||||
|
|
@ -505,13 +497,8 @@ class SAM2Model(torch.nn.Module):
|
|||
if self.pred_obj_scores:
|
||||
is_obj_appearing = object_score_logits > 0
|
||||
|
||||
# Mask used for spatial memories is always a *hard* choice between obj and no obj,
|
||||
# consistent with the actual mask prediction
|
||||
low_res_multimasks = torch.where(
|
||||
is_obj_appearing[:, None, None],
|
||||
low_res_multimasks,
|
||||
NO_OBJ_SCORE,
|
||||
)
|
||||
# Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction
|
||||
low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE)
|
||||
|
||||
# convert masks from possibly bfloat16 (or float16) to float32
|
||||
# (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
|
||||
|
|
@ -617,7 +604,6 @@ class SAM2Model(torch.nn.Module):
|
|||
|
||||
def _prepare_backbone_features(self, backbone_out):
|
||||
"""Prepares and flattens visual features from the image backbone output for further processing."""
|
||||
backbone_out = backbone_out.copy()
|
||||
assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
|
||||
assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
|
||||
|
||||
|
|
@ -826,11 +812,7 @@ class SAM2Model(torch.nn.Module):
|
|||
mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
|
||||
if self.sigmoid_bias_for_mem_enc != 0.0:
|
||||
mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
|
||||
maskmem_out = self.memory_encoder(
|
||||
pix_feat,
|
||||
mask_for_mem,
|
||||
skip_mask_sigmoid=True, # sigmoid already applied
|
||||
)
|
||||
maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied
|
||||
maskmem_features = maskmem_out["vision_features"]
|
||||
maskmem_pos_enc = maskmem_out["vision_pos_enc"]
|
||||
# add a no-object embedding to the spatial memory to indicate that the frame
|
||||
|
|
@ -965,16 +947,7 @@ class SAM2Model(torch.nn.Module):
|
|||
track_in_reverse,
|
||||
prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
low_res_masks,
|
||||
high_res_masks,
|
||||
obj_ptr,
|
||||
object_score_logits,
|
||||
) = sam_outputs
|
||||
_, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs
|
||||
|
||||
current_out["pred_masks"] = low_res_masks
|
||||
current_out["pred_masks_high_res"] = high_res_masks
|
||||
|
|
@ -984,8 +957,7 @@ class SAM2Model(torch.nn.Module):
|
|||
# it's mainly used in the demo to encode spatial memories w/ consolidated masks)
|
||||
current_out["object_score_logits"] = object_score_logits
|
||||
|
||||
# Finally run the memory encoder on the predicted mask to encode
|
||||
# it into a new memory feature (that can be used in future frames)
|
||||
# Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames)
|
||||
self._encode_memory_in_output(
|
||||
current_vision_feats,
|
||||
feat_sizes,
|
||||
|
|
@ -1007,8 +979,9 @@ class SAM2Model(torch.nn.Module):
|
|||
and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
|
||||
)
|
||||
|
||||
def _apply_non_overlapping_constraints(self, pred_masks):
|
||||
"""Applies non-overlapping constraints to masks, keeping highest scoring object per location."""
|
||||
@staticmethod
|
||||
def _apply_non_overlapping_constraints(pred_masks):
|
||||
"""Applies non-overlapping constraints to masks, keeping the highest scoring object per location."""
|
||||
batch_size = pred_masks.size(0)
|
||||
if batch_size == 1:
|
||||
return pred_masks
|
||||
|
|
@ -1024,6 +997,10 @@ class SAM2Model(torch.nn.Module):
|
|||
pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
|
||||
return pred_masks
|
||||
|
||||
def set_binarize(self, binarize=False):
|
||||
"""Set binarize for VideoPredictor."""
|
||||
self.binarize_mask_from_pts_for_mem_enc = binarize
|
||||
|
||||
def set_imgsz(self, imgsz):
|
||||
"""
|
||||
Set image size to make model compatible with different image sizes.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ using SAM. It forms an integral part of the Ultralytics framework and is designe
|
|||
segmentation tasks.
|
||||
"""
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -16,7 +18,7 @@ from ultralytics.data.augment import LetterBox
|
|||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
from ultralytics.utils.torch_utils import select_device
|
||||
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
from .amg import (
|
||||
batch_iterator,
|
||||
|
|
@ -95,7 +97,7 @@ class Predictor(BasePredictor):
|
|||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides.update(dict(task="segment", mode="predict"))
|
||||
overrides.update(dict(task="segment", mode="predict", batch=1))
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.retina_masks = True
|
||||
self.im = None
|
||||
|
|
@ -114,7 +116,7 @@ class Predictor(BasePredictor):
|
|||
im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
|
||||
im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -181,10 +183,9 @@ class Predictor(BasePredictor):
|
|||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
(tuple): Contains the following three elements:
|
||||
- np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
|
||||
- np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
(np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
|
||||
(np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -222,10 +223,8 @@ class Predictor(BasePredictor):
|
|||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores predicted by the model for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
|
||||
(np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): Quality scores predicted by the model for each mask, with length C.
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -329,10 +328,9 @@ class Predictor(BasePredictor):
|
|||
crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
|
||||
|
||||
Returns:
|
||||
(Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
|
||||
- pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
|
||||
- pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
|
||||
- pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
|
||||
pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
|
||||
pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
|
||||
pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
|
||||
|
||||
Examples:
|
||||
>>> predictor = Predictor()
|
||||
|
|
@ -408,7 +406,7 @@ class Predictor(BasePredictor):
|
|||
|
||||
return pred_masks, pred_scores, pred_bboxes
|
||||
|
||||
def setup_model(self, model, verbose=True):
|
||||
def setup_model(self, model=None, verbose=True):
|
||||
"""
|
||||
Initializes the Segment Anything Model (SAM) for inference.
|
||||
|
||||
|
|
@ -416,7 +414,7 @@ class Predictor(BasePredictor):
|
|||
parameters for image normalization and other Ultralytics compatibility settings.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
|
||||
model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config.
|
||||
verbose (bool): If True, prints selected device information.
|
||||
|
||||
Examples:
|
||||
|
|
@ -459,7 +457,7 @@ class Predictor(BasePredictor):
|
|||
orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
|
||||
|
||||
Returns:
|
||||
(List[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
|
||||
metadata for each processed image.
|
||||
|
||||
Examples:
|
||||
|
|
@ -586,9 +584,8 @@ class Predictor(BasePredictor):
|
|||
nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
|
||||
|
||||
Returns:
|
||||
(tuple):
|
||||
- new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
|
||||
- keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
|
||||
new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
|
||||
keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
|
||||
|
||||
Examples:
|
||||
>>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks
|
||||
|
|
@ -690,10 +687,8 @@ class SAM2Predictor(Predictor):
|
|||
img_idx (int): Index of the image in the batch to process.
|
||||
|
||||
Returns:
|
||||
(tuple): Tuple containing:
|
||||
- np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
- np.ndarray: Quality scores for each mask, with length C.
|
||||
- np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
|
||||
(np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks.
|
||||
(np.ndarray): Quality scores for each mask, with length C.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2Predictor(cfg)
|
||||
|
|
@ -712,7 +707,7 @@ class SAM2Predictor(Predictor):
|
|||
"""
|
||||
features = self.get_im_features(im) if self.features is None else self.features
|
||||
|
||||
bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
points = (points, labels) if points is not None else None
|
||||
|
||||
sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
|
||||
|
|
@ -751,7 +746,7 @@ class SAM2Predictor(Predictor):
|
|||
AssertionError: If the number of points don't match the number of labels, in case labels were passed.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing transformed bounding boxes, points, labels, and masks.
|
||||
(tuple): A tuple containing transformed points, labels, and masks.
|
||||
"""
|
||||
bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks)
|
||||
if bboxes is not None:
|
||||
|
|
@ -764,7 +759,7 @@ class SAM2Predictor(Predictor):
|
|||
labels = torch.cat([bbox_labels, labels], dim=1)
|
||||
else:
|
||||
points, labels = bboxes, bbox_labels
|
||||
return bboxes, points, labels, masks
|
||||
return points, labels, masks
|
||||
|
||||
def set_image(self, image):
|
||||
"""
|
||||
|
|
@ -815,3 +810,797 @@ class SAM2Predictor(Predictor):
|
|||
for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
|
||||
][::-1]
|
||||
return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
|
||||
|
||||
|
||||
class SAM2VideoPredictor(SAM2Predictor):
|
||||
"""
|
||||
SAM2VideoPredictor to handle user interactions with videos and manage inference states.
|
||||
|
||||
This class extends the functionality of SAM2Predictor to support video processing and maintains
|
||||
the state of inference operations. It includes configurations for managing non-overlapping masks,
|
||||
clearing memory for non-conditional inputs, and setting up callbacks for prediction events.
|
||||
|
||||
Attributes:
|
||||
inference_state (Dict): A dictionary to store the current state of inference operations.
|
||||
non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping.
|
||||
clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs.
|
||||
clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios.
|
||||
callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events.
|
||||
|
||||
Args:
|
||||
cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG.
|
||||
overrides (Dict, Optional): Additional configuration overrides. Defaults to None.
|
||||
_callbacks (List, Optional): Custom callbacks to be added. Defaults to None.
|
||||
|
||||
Note:
|
||||
The `fill_hole_area` attribute is defined but not used in the current implementation.
|
||||
"""
|
||||
|
||||
# fill_hole_area = 8 # not used
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the predictor with configuration and optional overrides.
|
||||
|
||||
This constructor initializes the SAM2VideoPredictor with a given configuration, applies any
|
||||
specified overrides, and sets up the inference state along with certain flags
|
||||
that control the behavior of the predictor.
|
||||
|
||||
Args:
|
||||
cfg (Dict): Configuration dictionary containing default settings.
|
||||
overrides (Dict | None): Dictionary of values to override default configuration.
|
||||
_callbacks (Dict | None): Dictionary of callback functions to customize behavior.
|
||||
|
||||
Examples:
|
||||
>>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG)
|
||||
>>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
|
||||
>>> predictor = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback})
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.inference_state = {}
|
||||
self.non_overlap_masks = True
|
||||
self.clear_non_cond_mem_around_input = False
|
||||
self.clear_non_cond_mem_for_multi_obj = False
|
||||
self.callbacks["on_predict_start"].append(self.init_state)
|
||||
|
||||
def get_model(self):
|
||||
"""
|
||||
Retrieves and configures the model with binarization enabled.
|
||||
|
||||
Note:
|
||||
This method overrides the base class implementation to set the binarize flag to True.
|
||||
"""
|
||||
model = super().get_model()
|
||||
model.set_binarize(True)
|
||||
return model
|
||||
|
||||
def inference(self, im, bboxes=None, points=None, labels=None, masks=None):
|
||||
"""
|
||||
Perform image segmentation inference based on the given input cues, using the currently loaded image. This
|
||||
method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
|
||||
mask decoder for real-time and promptable segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
||||
masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks.
|
||||
(np.ndarray): An array of length C containing quality scores predicted by the model for each mask.
|
||||
"""
|
||||
# Override prompts if any stored in self.prompts
|
||||
bboxes = self.prompts.pop("bboxes", bboxes)
|
||||
points = self.prompts.pop("points", points)
|
||||
masks = self.prompts.pop("masks", masks)
|
||||
|
||||
frame = self.dataset.frame
|
||||
self.inference_state["im"] = im
|
||||
output_dict = self.inference_state["output_dict"]
|
||||
if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts
|
||||
points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks)
|
||||
if points is not None:
|
||||
for i in range(len(points)):
|
||||
self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame)
|
||||
elif masks is not None:
|
||||
for i in range(len(masks)):
|
||||
self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame)
|
||||
self.propagate_in_video_preflight()
|
||||
|
||||
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
if len(output_dict["cond_frame_outputs"]) == 0:
|
||||
raise RuntimeError("No points are provided; please add points first")
|
||||
|
||||
if frame in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
storage_key = "cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame]
|
||||
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(frame)
|
||||
elif frame in consolidated_frame_inds["non_cond_frame_outputs"]:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = output_dict[storage_key][frame]
|
||||
else:
|
||||
storage_key = "non_cond_frame_outputs"
|
||||
current_out = self._run_single_frame_inference(
|
||||
output_dict=output_dict,
|
||||
frame_idx=frame,
|
||||
batch_size=batch_size,
|
||||
is_init_cond_frame=False,
|
||||
point_inputs=None,
|
||||
mask_inputs=None,
|
||||
reverse=False,
|
||||
run_mem_encoder=True,
|
||||
)
|
||||
output_dict[storage_key][frame] = current_out
|
||||
# Create slices of per-object outputs for subsequent interaction with each
|
||||
# individual object after tracking.
|
||||
self._add_output_per_object(frame, current_out, storage_key)
|
||||
self.inference_state["frames_already_tracked"].append(frame)
|
||||
pred_masks = current_out["pred_masks"].flatten(0, 1)
|
||||
pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks
|
||||
|
||||
return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device)
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Post-processes the predictions to apply non-overlapping constraints if required.
|
||||
|
||||
This method extends the post-processing functionality by applying non-overlapping constraints
|
||||
to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that
|
||||
the masks do not overlap, which can be useful for certain applications.
|
||||
|
||||
Args:
|
||||
preds (Tuple[torch.Tensor]): The predictions from the model.
|
||||
img (torch.Tensor): The processed image tensor.
|
||||
orig_imgs (List[np.ndarray]): The original images before processing.
|
||||
|
||||
Returns:
|
||||
results (list): The post-processed predictions.
|
||||
|
||||
Note:
|
||||
If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks.
|
||||
"""
|
||||
results = super().postprocess(preds, img, orig_imgs)
|
||||
if self.non_overlap_masks:
|
||||
for result in results:
|
||||
if result.masks is None or len(result.masks) == 0:
|
||||
continue
|
||||
result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0]
|
||||
return results
|
||||
|
||||
@smart_inference_mode()
|
||||
def add_new_prompts(
|
||||
self,
|
||||
obj_id,
|
||||
points=None,
|
||||
labels=None,
|
||||
masks=None,
|
||||
frame_idx=0,
|
||||
):
|
||||
"""
|
||||
Adds new points or masks to a specific frame for a given object ID.
|
||||
|
||||
This method updates the inference state with new prompts (points or masks) for a specified
|
||||
object and frame index. It ensures that the prompts are either points or masks, but not both,
|
||||
and updates the internal state accordingly. It also handles the generation of new segmentations
|
||||
based on the provided prompts and the existing state.
|
||||
|
||||
Args:
|
||||
obj_id (int): The ID of the object to which the prompts are associated.
|
||||
points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None.
|
||||
labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None.
|
||||
masks (torch.Tensor, optional): Binary masks for the object. Defaults to None.
|
||||
frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects.
|
||||
|
||||
Raises:
|
||||
AssertionError: If both `masks` and `points` are provided, or neither is provided.
|
||||
|
||||
Note:
|
||||
- Only one type of prompt (either points or masks) can be added per call.
|
||||
- If the frame is being tracked for the first time, it is treated as an initial conditioning frame.
|
||||
- The method handles the consolidation of outputs and resizing of masks to the original video resolution.
|
||||
"""
|
||||
assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other."
|
||||
obj_idx = self._obj_id_to_idx(obj_id)
|
||||
|
||||
point_inputs = None
|
||||
pop_key = "point_inputs_per_obj"
|
||||
if points is not None:
|
||||
point_inputs = {"point_coords": points, "point_labels": labels}
|
||||
self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs
|
||||
pop_key = "mask_inputs_per_obj"
|
||||
self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks
|
||||
self.inference_state[pop_key][obj_idx].pop(frame_idx, None)
|
||||
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
||||
# frame, meaning that the inputs points are to generate segments on this frame without
|
||||
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
||||
# the input points will be used to correct the already tracked masks.
|
||||
is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"]
|
||||
obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
|
||||
obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
# Add a frame to conditioning output if it's an initial conditioning frame or
|
||||
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
||||
is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
|
||||
# Get any previously predicted mask logits on this object and feed it along with
|
||||
# the new clicks into the SAM mask decoder.
|
||||
prev_sam_mask_logits = None
|
||||
# lookup temporary output dict first, which contains the most recent output
|
||||
# (if not found, then lookup conditioning and non-conditioning frame output)
|
||||
if point_inputs is not None:
|
||||
prev_out = (
|
||||
obj_temp_output_dict[storage_key].get(frame_idx)
|
||||
or obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
||||
or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
)
|
||||
|
||||
if prev_out is not None and prev_out.get("pred_masks") is not None:
|
||||
prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True)
|
||||
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
||||
prev_sam_mask_logits.clamp_(-32.0, 32.0)
|
||||
current_out = self._run_single_frame_inference(
|
||||
output_dict=obj_output_dict, # run on the slice of a single object
|
||||
frame_idx=frame_idx,
|
||||
batch_size=1, # run on the slice of a single object
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=masks,
|
||||
reverse=False,
|
||||
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
||||
# at the beginning of `propagate_in_video` (after user finalize their clicks). This
|
||||
# allows us to enforce non-overlapping constraints on all objects before encoding
|
||||
# them into memory.
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||
)
|
||||
# Add the output to the output dict (to be used as future memory)
|
||||
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
||||
|
||||
# Resize the output mask to the original video resolution
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
frame_idx,
|
||||
is_cond=is_cond,
|
||||
run_mem_encoder=False,
|
||||
)
|
||||
pred_masks = consolidated_out["pred_masks"].flatten(0, 1)
|
||||
return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device)
|
||||
|
||||
@smart_inference_mode()
|
||||
def propagate_in_video_preflight(self):
|
||||
"""
|
||||
Prepare inference_state and consolidate temporary outputs before tracking.
|
||||
|
||||
This method marks the start of tracking, disallowing the addition of new objects until the session is reset.
|
||||
It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`.
|
||||
Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent
|
||||
with the provided inputs.
|
||||
"""
|
||||
# Tracking has started and we don't allow adding new objects until session is reset.
|
||||
self.inference_state["tracking_has_started"] = True
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
|
||||
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
||||
# add them into "output_dict".
|
||||
temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"]
|
||||
output_dict = self.inference_state["output_dict"]
|
||||
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
||||
# temporary outputs have been added (either in this call or any previous calls
|
||||
# to `propagate_in_video_preflight`).
|
||||
consolidated_frame_inds = self.inference_state["consolidated_frame_inds"]
|
||||
for is_cond in {False, True}:
|
||||
# Separately consolidate conditioning and non-conditioning temp outptus
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
# Find all the frames that contain temporary outputs for any objects
|
||||
# (these should be the frames that have just received clicks for mask inputs
|
||||
# via `add_new_points` or `add_new_mask`)
|
||||
temp_frame_inds = set()
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
||||
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
||||
# consolidate the temprary output across all objects on this frame
|
||||
for frame_idx in temp_frame_inds:
|
||||
consolidated_out = self._consolidate_temp_output_across_obj(
|
||||
frame_idx, is_cond=is_cond, run_mem_encoder=True
|
||||
)
|
||||
# merge them into "output_dict" and also create per-object slices
|
||||
output_dict[storage_key][frame_idx] = consolidated_out
|
||||
self._add_output_per_object(frame_idx, consolidated_out, storage_key)
|
||||
if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1):
|
||||
# clear non-conditioning memory of the surrounding frames
|
||||
self._clear_non_cond_mem_around_input(frame_idx)
|
||||
|
||||
# clear temporary outputs in `temp_output_dict_per_obj`
|
||||
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
||||
obj_temp_output_dict[storage_key].clear()
|
||||
|
||||
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
||||
# output on the same frame in "non_cond_frame_outputs"
|
||||
for frame_idx in output_dict["cond_frame_outputs"]:
|
||||
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
|
||||
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
||||
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
||||
assert frame_idx in output_dict["cond_frame_outputs"]
|
||||
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
||||
|
||||
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
||||
# with either points or mask inputs (which should be true under a correct workflow).
|
||||
all_consolidated_frame_inds = (
|
||||
consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"]
|
||||
)
|
||||
input_frames_inds = set()
|
||||
for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values():
|
||||
input_frames_inds.update(point_inputs_per_frame.keys())
|
||||
for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values():
|
||||
input_frames_inds.update(mask_inputs_per_frame.keys())
|
||||
assert all_consolidated_frame_inds == input_frames_inds
|
||||
|
||||
@staticmethod
|
||||
def init_state(predictor):
|
||||
"""
|
||||
Initialize an inference state for the predictor.
|
||||
|
||||
This function sets up the initial state required for performing inference on video data.
|
||||
It includes initializing various dictionaries and ordered dictionaries that will store
|
||||
inputs, outputs, and other metadata relevant to the tracking process.
|
||||
|
||||
Args:
|
||||
predictor (SAM2VideoPredictor): The predictor object for which to initialize the state.
|
||||
"""
|
||||
if len(predictor.inference_state) > 0: # means initialized
|
||||
return
|
||||
assert predictor.dataset is not None
|
||||
assert predictor.dataset.mode == "video"
|
||||
|
||||
inference_state = {}
|
||||
inference_state["num_frames"] = predictor.dataset.frames
|
||||
# inputs on each frame
|
||||
inference_state["point_inputs_per_obj"] = {}
|
||||
inference_state["mask_inputs_per_obj"] = {}
|
||||
# values that don't change across frames (so we only need to hold one copy of them)
|
||||
inference_state["constants"] = {}
|
||||
# mapping between client-side object id and model-side object index
|
||||
inference_state["obj_id_to_idx"] = OrderedDict()
|
||||
inference_state["obj_idx_to_id"] = OrderedDict()
|
||||
inference_state["obj_ids"] = []
|
||||
# A storage to hold the model's tracking results and states on each frame
|
||||
inference_state["output_dict"] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
||||
inference_state["output_dict_per_obj"] = {}
|
||||
# A temporary storage to hold new outputs when user interact with a frame
|
||||
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
||||
inference_state["temp_output_dict_per_obj"] = {}
|
||||
# Frames that already holds consolidated outputs from click or mask inputs
|
||||
# (we directly use their consolidated outputs during tracking)
|
||||
inference_state["consolidated_frame_inds"] = {
|
||||
"cond_frame_outputs": set(), # set containing frame indices
|
||||
"non_cond_frame_outputs": set(), # set containing frame indices
|
||||
}
|
||||
# metadata for each tracking frame (e.g. which direction it's tracked)
|
||||
inference_state["tracking_has_started"] = False
|
||||
inference_state["frames_already_tracked"] = []
|
||||
predictor.inference_state = inference_state
|
||||
|
||||
def get_im_features(self, im, batch=1):
|
||||
"""
|
||||
Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): The input image tensor.
|
||||
batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
vis_feats (torch.Tensor): The visual features extracted from the image.
|
||||
vis_pos_embed (torch.Tensor): The positional embeddings for the visual features.
|
||||
feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features.
|
||||
|
||||
Note:
|
||||
- If `batch` is greater than 1, the features are expanded to fit the batch size.
|
||||
- The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features.
|
||||
"""
|
||||
backbone_out = self.model.forward_image(im)
|
||||
if batch > 1: # expand features if there's more than one prompt
|
||||
for i, feat in enumerate(backbone_out["backbone_fpn"]):
|
||||
backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1)
|
||||
for i, pos in enumerate(backbone_out["vision_pos_enc"]):
|
||||
pos = pos.expand(batch, -1, -1, -1)
|
||||
backbone_out["vision_pos_enc"][i] = pos
|
||||
_, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out)
|
||||
return vis_feats, vis_pos_embed, feat_sizes
|
||||
|
||||
def _obj_id_to_idx(self, obj_id):
|
||||
"""
|
||||
Map client-side object id to model-side object index.
|
||||
|
||||
Args:
|
||||
obj_id (int): The unique identifier of the object provided by the client side.
|
||||
|
||||
Returns:
|
||||
obj_idx (int): The index of the object on the model side.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If an attempt is made to add a new object after tracking has started.
|
||||
|
||||
Note:
|
||||
- The method updates or retrieves mappings between object IDs and indices stored in
|
||||
`inference_state`.
|
||||
- It ensures that new objects can only be added before tracking commences.
|
||||
- It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`).
|
||||
- Additional data structures are initialized for the new object to store inputs and outputs.
|
||||
"""
|
||||
obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None)
|
||||
if obj_idx is not None:
|
||||
return obj_idx
|
||||
|
||||
# This is a new object id not sent to the server before. We only allow adding
|
||||
# new objects *before* the tracking starts.
|
||||
allow_new_object = not self.inference_state["tracking_has_started"]
|
||||
if allow_new_object:
|
||||
# get the next object slot
|
||||
obj_idx = len(self.inference_state["obj_id_to_idx"])
|
||||
self.inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
||||
self.inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
||||
self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"])
|
||||
# set up input and output structures for this object
|
||||
self.inference_state["point_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
||||
self.inference_state["output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
self.inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
||||
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
||||
}
|
||||
return obj_idx
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Cannot add new object id {obj_id} after tracking starts. "
|
||||
f"All existing object ids: {self.inference_state['obj_ids']}. "
|
||||
f"Please call 'reset_state' to restart from scratch."
|
||||
)
|
||||
|
||||
def _run_single_frame_inference(
|
||||
self,
|
||||
output_dict,
|
||||
frame_idx,
|
||||
batch_size,
|
||||
is_init_cond_frame,
|
||||
point_inputs,
|
||||
mask_inputs,
|
||||
reverse,
|
||||
run_mem_encoder,
|
||||
prev_sam_mask_logits=None,
|
||||
):
|
||||
"""
|
||||
Run tracking on a single frame based on current inputs and previous memory.
|
||||
|
||||
Args:
|
||||
output_dict (Dict): The dictionary containing the output states of the tracking process.
|
||||
frame_idx (int): The index of the current frame.
|
||||
batch_size (int): The batch size for processing the frame.
|
||||
is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame.
|
||||
point_inputs (Dict, Optional): Input points and their labels. Defaults to None.
|
||||
mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None.
|
||||
reverse (bool): Indicates if the tracking should be performed in reverse order.
|
||||
run_mem_encoder (bool): Indicates if the memory encoder should be executed.
|
||||
prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None.
|
||||
|
||||
Returns:
|
||||
current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions.
|
||||
|
||||
Raises:
|
||||
AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided.
|
||||
|
||||
Note:
|
||||
- The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive.
|
||||
- The method retrieves image features using the `get_im_features` method.
|
||||
- The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored.
|
||||
- The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(
|
||||
self.inference_state["im"], batch_size
|
||||
)
|
||||
|
||||
# point and mask should not appear as input simultaneously on the same frame
|
||||
assert point_inputs is None or mask_inputs is None
|
||||
current_out = self.model.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=is_init_cond_frame,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=point_inputs,
|
||||
mask_inputs=mask_inputs,
|
||||
output_dict=output_dict,
|
||||
num_frames=self.inference_state["num_frames"],
|
||||
track_in_reverse=reverse,
|
||||
run_mem_encoder=run_mem_encoder,
|
||||
prev_sam_mask_logits=prev_sam_mask_logits,
|
||||
)
|
||||
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
if maskmem_features is not None:
|
||||
current_out["maskmem_features"] = maskmem_features.to(
|
||||
dtype=torch.float16, device=self.device, non_blocking=True
|
||||
)
|
||||
# NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions
|
||||
# potentially fill holes in the predicted masks
|
||||
# if self.fill_hole_area > 0:
|
||||
# pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True)
|
||||
# pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area)
|
||||
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"])
|
||||
return current_out
|
||||
|
||||
def _get_maskmem_pos_enc(self, out_maskmem_pos_enc):
|
||||
"""
|
||||
Caches and manages the positional encoding for mask memory across frames and objects.
|
||||
|
||||
This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for
|
||||
mask memory, which is constant across frames and objects, thus reducing the amount of
|
||||
redundant information stored during an inference session. It checks if the positional
|
||||
encoding has already been cached; if not, it caches a slice of the provided encoding.
|
||||
If the batch size is greater than one, it expands the cached positional encoding to match
|
||||
the current batch size.
|
||||
|
||||
Args:
|
||||
out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory.
|
||||
Should be a list of tensors or None.
|
||||
|
||||
Returns:
|
||||
out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded.
|
||||
|
||||
Note:
|
||||
- The method assumes that `out_maskmem_pos_enc` is a list of tensors or None.
|
||||
- Only a single object's slice is cached since the encoding is the same across objects.
|
||||
- The method checks if the positional encoding has already been cached in the session's constants.
|
||||
- If the batch size is greater than one, the cached encoding is expanded to fit the batch size.
|
||||
"""
|
||||
model_constants = self.inference_state["constants"]
|
||||
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
||||
if out_maskmem_pos_enc is not None:
|
||||
if "maskmem_pos_enc" not in model_constants:
|
||||
assert isinstance(out_maskmem_pos_enc, list)
|
||||
# only take the slice for one object, since it's same across objects
|
||||
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
||||
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
||||
else:
|
||||
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
||||
# expand the cached maskmem_pos_enc to the actual batch size
|
||||
batch_size = out_maskmem_pos_enc[0].size(0)
|
||||
if batch_size > 1:
|
||||
out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc]
|
||||
return out_maskmem_pos_enc
|
||||
|
||||
def _consolidate_temp_output_across_obj(
|
||||
self,
|
||||
frame_idx,
|
||||
is_cond=False,
|
||||
run_mem_encoder=False,
|
||||
):
|
||||
"""
|
||||
Consolidates per-object temporary outputs into a single output for all objects.
|
||||
|
||||
This method combines the temporary outputs for each object on a given frame into a unified
|
||||
output. It fills in any missing objects either from the main output dictionary or leaves
|
||||
placeholders if they do not exist in the main output. Optionally, it can re-run the memory
|
||||
encoder after applying non-overlapping constraints to the object scores.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the frame for which to consolidate outputs.
|
||||
is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame.
|
||||
Defaults to False.
|
||||
run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after
|
||||
consolidating the outputs. Defaults to False.
|
||||
|
||||
Returns:
|
||||
consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects.
|
||||
|
||||
Note:
|
||||
- The method initializes the consolidated output with placeholder values for missing objects.
|
||||
- It searches for outputs in both the temporary and main output dictionaries.
|
||||
- If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder.
|
||||
- The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True.
|
||||
"""
|
||||
batch_size = len(self.inference_state["obj_idx_to_id"])
|
||||
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
||||
|
||||
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
||||
# will be added when rerunning the memory encoder after applying non-overlapping
|
||||
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
||||
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
||||
consolidated_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": torch.full(
|
||||
size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4),
|
||||
fill_value=-1024.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
"obj_ptr": torch.full(
|
||||
size=(batch_size, self.model.hidden_dim),
|
||||
fill_value=-1024.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
"object_score_logits": torch.full(
|
||||
size=(batch_size, 1),
|
||||
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
||||
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
||||
fill_value=10.0,
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
for obj_idx in range(batch_size):
|
||||
obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx]
|
||||
obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx]
|
||||
out = (
|
||||
obj_temp_output_dict[storage_key].get(frame_idx)
|
||||
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
||||
# we fall back and look up its previous output in "output_dict_per_obj".
|
||||
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
||||
# "output_dict_per_obj" to find a previous output for this object.
|
||||
or obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
||||
or obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
||||
)
|
||||
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
||||
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
||||
# placeholder above) and set its object pointer to be a dummy pointer.
|
||||
if out is None:
|
||||
# Fill in dummy object pointers for those objects without any inputs or
|
||||
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
||||
# i.e. when we need to build the memory for tracking).
|
||||
if run_mem_encoder:
|
||||
# fill object pointer with a dummy pointer (based on an empty mask)
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx)
|
||||
continue
|
||||
# Add the temporary object output mask to consolidated output mask
|
||||
consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"]
|
||||
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
||||
|
||||
# Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder
|
||||
if run_mem_encoder:
|
||||
high_res_masks = F.interpolate(
|
||||
consolidated_out["pred_masks"],
|
||||
size=self.imgsz,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
if self.model.non_overlap_masks_for_mem_enc:
|
||||
high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks)
|
||||
consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder(
|
||||
batch_size=batch_size,
|
||||
high_res_masks=high_res_masks,
|
||||
is_mask_from_pts=True, # these frames are what the user interacted with
|
||||
object_score_logits=consolidated_out["object_score_logits"],
|
||||
)
|
||||
|
||||
return consolidated_out
|
||||
|
||||
def _get_empty_mask_ptr(self, frame_idx):
|
||||
"""
|
||||
Get a dummy object pointer based on an empty mask on the current frame.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame for which to generate the dummy object pointer.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"])
|
||||
|
||||
# Feed the empty mask and image feature above to get a dummy object pointer
|
||||
current_out = self.model.track_step(
|
||||
frame_idx=frame_idx,
|
||||
is_init_cond_frame=True,
|
||||
current_vision_feats=current_vision_feats,
|
||||
current_vision_pos_embeds=current_vision_pos_embeds,
|
||||
feat_sizes=feat_sizes,
|
||||
point_inputs=None,
|
||||
# A dummy (empty) mask with a single object
|
||||
mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device),
|
||||
output_dict={},
|
||||
num_frames=self.inference_state["num_frames"],
|
||||
track_in_reverse=False,
|
||||
run_mem_encoder=False,
|
||||
prev_sam_mask_logits=None,
|
||||
)
|
||||
return current_out["obj_ptr"]
|
||||
|
||||
def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts):
|
||||
"""
|
||||
Run the memory encoder on masks.
|
||||
|
||||
This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their
|
||||
memory also needs to be computed again with the memory encoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): The batch size for processing the frame.
|
||||
high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory.
|
||||
object_score_logits (torch.Tensor): Logits representing the object scores.
|
||||
is_mask_from_pts (bool): Indicates if the mask is derived from point interactions.
|
||||
|
||||
Returns:
|
||||
(tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding.
|
||||
"""
|
||||
# Retrieve correct image features
|
||||
current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size)
|
||||
maskmem_features, maskmem_pos_enc = self.model._encode_new_memory(
|
||||
current_vision_feats=current_vision_feats,
|
||||
feat_sizes=feat_sizes,
|
||||
pred_masks_high_res=high_res_masks,
|
||||
is_mask_from_pts=is_mask_from_pts,
|
||||
object_score_logits=object_score_logits,
|
||||
)
|
||||
|
||||
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
||||
maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc)
|
||||
return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc
|
||||
|
||||
def _add_output_per_object(self, frame_idx, current_out, storage_key):
|
||||
"""
|
||||
Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj.
|
||||
|
||||
The resulting slices share the same tensor storage.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame.
|
||||
current_out (Dict): The current output dictionary containing multi-object outputs.
|
||||
storage_key (str): The key used to store the output in the per-object output dictionary.
|
||||
"""
|
||||
maskmem_features = current_out["maskmem_features"]
|
||||
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
||||
|
||||
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
||||
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
||||
|
||||
for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items():
|
||||
obj_slice = slice(obj_idx, obj_idx + 1)
|
||||
obj_out = {
|
||||
"maskmem_features": None,
|
||||
"maskmem_pos_enc": None,
|
||||
"pred_masks": current_out["pred_masks"][obj_slice],
|
||||
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
||||
}
|
||||
if maskmem_features is not None:
|
||||
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
||||
if maskmem_pos_enc is not None:
|
||||
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
||||
obj_output_dict[storage_key][frame_idx] = obj_out
|
||||
|
||||
def _clear_non_cond_mem_around_input(self, frame_idx):
|
||||
"""
|
||||
Remove the non-conditioning memory around the input frame.
|
||||
|
||||
When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated
|
||||
object appearance information and could confuse the model. This method clears those non-conditioning memories
|
||||
surrounding the interacted frame to avoid giving the model both old and new information about the object.
|
||||
|
||||
Args:
|
||||
frame_idx (int): The index of the current frame where user interaction occurred.
|
||||
"""
|
||||
r = self.model.memory_temporal_stride_for_eval
|
||||
frame_idx_begin = frame_idx - r * self.model.num_maskmem
|
||||
frame_idx_end = frame_idx + r * self.model.num_maskmem
|
||||
for t in range(frame_idx_begin, frame_idx_end + 1):
|
||||
self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None)
|
||||
for obj_output_dict in self.inference_state["output_dict_per_obj"].values():
|
||||
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class BaseTrack:
|
|||
start_frame (int): The frame number where tracking started.
|
||||
frame_id (int): The most recent frame ID processed by the track.
|
||||
time_since_update (int): Frames passed since the last update.
|
||||
location (Tuple): The location of the object in the context of multi-camera tracking.
|
||||
location (tuple): The location of the object in the context of multi-camera tracking.
|
||||
|
||||
Methods:
|
||||
end_frame: Returns the ID of the last frame where the object was tracked.
|
||||
|
|
|
|||
|
|
@ -27,10 +27,9 @@ def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = Tr
|
|||
use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing:
|
||||
- matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
|
||||
- unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
|
||||
- unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
|
||||
matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
|
||||
unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
|
||||
unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
|
||||
|
||||
Examples:
|
||||
>>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
|
||||
|
|
|
|||
|
|
@ -552,9 +552,8 @@ class v8PoseLoss(v8DetectionLoss):
|
|||
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
|
||||
|
||||
Returns:
|
||||
(tuple): Returns a tuple containing:
|
||||
- kpts_loss (torch.Tensor): The keypoints loss.
|
||||
- kpts_obj_loss (torch.Tensor): The keypoints object loss.
|
||||
kpts_loss (torch.Tensor): The keypoints loss.
|
||||
kpts_obj_loss (torch.Tensor): The keypoints object loss.
|
||||
"""
|
||||
batch_idx = batch_idx.flatten()
|
||||
batch_size = len(masks)
|
||||
|
|
|
|||
|
|
@ -549,19 +549,18 @@ def ap_per_class(
|
|||
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple of six arrays and one array of unique classes, where:
|
||||
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
|
||||
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
|
||||
unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
|
||||
p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
|
||||
r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
|
||||
f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
|
||||
x (np.ndarray): X-axis values for the curves. Shape: (1000,).
|
||||
prec_values: Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
|
||||
tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,).
|
||||
fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,).
|
||||
ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10).
|
||||
unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,).
|
||||
p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000).
|
||||
r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000).
|
||||
f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000).
|
||||
x (np.ndarray): X-axis values for the curves. Shape: (1000,).
|
||||
prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000).
|
||||
"""
|
||||
# Sort by objectness
|
||||
i = np.argsort(-conf)
|
||||
|
|
|
|||
|
|
@ -317,11 +317,11 @@ def clip_boxes(boxes, shape):
|
|||
Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): the bounding boxes to clip
|
||||
shape (tuple): the shape of the image
|
||||
boxes (torch.Tensor): The bounding boxes to clip.
|
||||
shape (tuple): The shape of the image.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor | numpy.ndarray): Clipped boxes
|
||||
(torch.Tensor | numpy.ndarray): The clipped boxes.
|
||||
"""
|
||||
if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug)
|
||||
boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1
|
||||
|
|
@ -359,9 +359,9 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|||
Takes a mask, and resizes it to the original image size.
|
||||
|
||||
Args:
|
||||
masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||
im0_shape (tuple): the original image shape
|
||||
ratio_pad (tuple): the ratio of the padding to the original image.
|
||||
masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3].
|
||||
im0_shape (tuple): The original image shape.
|
||||
ratio_pad (tuple): The ratio of the padding to the original image.
|
||||
|
||||
Returns:
|
||||
masks (np.ndarray): The masks that are being returned with shape [h, w, num].
|
||||
|
|
@ -692,12 +692,12 @@ def process_mask_native(protos, masks_in, bboxes, shape):
|
|||
|
||||
Args:
|
||||
protos (torch.Tensor): [mask_dim, mask_h, mask_w]
|
||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
|
||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms
|
||||
shape (tuple): the size of the input image (h,w)
|
||||
masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms.
|
||||
bboxes (torch.Tensor): [n, 4], n is number of masks after nms.
|
||||
shape (tuple): The size of the input image (h,w).
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): The returned masks with dimensions [h, w, n]
|
||||
masks (torch.Tensor): The returned masks with dimensions [h, w, n].
|
||||
"""
|
||||
c, mh, mw = protos.shape # CHW
|
||||
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
|
||||
|
|
|
|||
|
|
@ -584,8 +584,8 @@ class Annotator:
|
|||
Displays queue counts on an image centered at the points with customizable font size and colors.
|
||||
|
||||
Args:
|
||||
label (str): queue counts label
|
||||
points (tuple): region points for center point calculation to display text
|
||||
label (str): Queue counts label.
|
||||
points (tuple): Region points for center point calculation to display text.
|
||||
region_color (tuple): RGB queue region color.
|
||||
txt_color (tuple): RGB text display color.
|
||||
"""
|
||||
|
|
@ -624,13 +624,13 @@ class Annotator:
|
|||
Display the bounding boxes labels in parking management app.
|
||||
|
||||
Args:
|
||||
im0 (ndarray): inference image
|
||||
text (str): object/class name
|
||||
txt_color (tuple): display color for text foreground
|
||||
bg_color (tuple): display color for text background
|
||||
x_center (float): x position center point for bounding box
|
||||
y_center (float): y position center point for bounding box
|
||||
margin (int): gap between text and rectangle for better display
|
||||
im0 (ndarray): Inference image.
|
||||
text (str): Object/class name.
|
||||
txt_color (tuple): Display color for text foreground.
|
||||
bg_color (tuple): Display color for text background.
|
||||
x_center (float): The x position center point for bounding box.
|
||||
y_center (float): The y position center point for bounding box.
|
||||
margin (int): The gap between text and rectangle for better display.
|
||||
"""
|
||||
text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
|
||||
text_x = x_center - text_size[0] // 2
|
||||
|
|
@ -648,11 +648,11 @@ class Annotator:
|
|||
Display the overall statistics for parking lots.
|
||||
|
||||
Args:
|
||||
im0 (ndarray): inference image
|
||||
text (dict): labels dictionary
|
||||
txt_color (tuple): display color for text foreground
|
||||
bg_color (tuple): display color for text background
|
||||
margin (int): gap between text and rectangle for better display
|
||||
im0 (ndarray): Inference image.
|
||||
text (dict): Labels dictionary.
|
||||
txt_color (tuple): Display color for text foreground.
|
||||
bg_color (tuple): Display color for text background.
|
||||
margin (int): Gap between text and rectangle for better display.
|
||||
"""
|
||||
horizontal_gap = int(im0.shape[1] * 0.02)
|
||||
vertical_gap = int(im0.shape[0] * 0.01)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue