ultralytics 8.2.73 Meta SAM2 Refactor (#14867)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent bea4c93278
commit 5d9046abda

44 changed files with 4542 additions and 3624 deletions

The excerpt below covers the hunks from ultralytics/models/sam/predict.py.
@@ -34,35 +34,64 @@ from .build import build_sam
 class Predictor(BasePredictor):
     """
-    Predictor class for the Segment Anything Model (SAM), extending BasePredictor.
+    Predictor class for SAM, enabling real-time image segmentation with promptable capabilities.
 
-    The class provides an interface for model inference tailored to image segmentation tasks.
-    With advanced architecture and promptable segmentation capabilities, it facilitates flexible and real-time
-    mask generation. The class is capable of working with various types of prompts such as bounding boxes,
-    points, and low-resolution masks.
+    This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image
+    segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for
+    fine-grained control over segmentation results.
 
     Attributes:
-        cfg (dict): Configuration dictionary specifying model and task-related parameters.
-        overrides (dict): Dictionary containing values that override the default configuration.
-        _callbacks (dict): Dictionary of user-defined callback functions to augment behavior.
-        args (namespace): Namespace to hold command-line arguments or other operational variables.
-        im (torch.Tensor): Preprocessed input image tensor.
-        features (torch.Tensor): Extracted image features used for inference.
-        prompts (dict): Collection of various prompt types, such as bounding boxes and points.
-        segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones.
+        args (SimpleNamespace): Configuration arguments for the predictor.
+        model (torch.nn.Module): The loaded SAM model.
+        device (torch.device): The device (CPU or GPU) on which the model is loaded.
+        im (torch.Tensor): The preprocessed input image.
+        features (torch.Tensor): Extracted image features.
+        prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks).
+        segment_all (bool): Flag to indicate if full image segmentation should be performed.
+        mean (torch.Tensor): Mean values for image normalization.
+        std (torch.Tensor): Standard deviation values for image normalization.
+
+    Methods:
+        preprocess: Prepares input images for model inference.
+        pre_transform: Performs initial transformations on the input image.
+        inference: Performs segmentation inference based on input prompts.
+        prompt_inference: Internal function for prompt-based segmentation inference.
+        generate: Generates segmentation masks for an entire image.
+        setup_model: Initializes the SAM model for inference.
+        get_model: Builds and returns a SAM model.
+        postprocess: Post-processes model outputs to generate final results.
+        setup_source: Sets up the data source for inference.
+        set_image: Sets and preprocesses a single image for inference.
+        get_im_features: Extracts image features using the SAM image encoder.
+        set_prompts: Sets prompts for subsequent inference.
+        reset_image: Resets the current image and its features.
+        remove_small_regions: Removes small disconnected regions and holes from masks.
+
+    Examples:
+        >>> predictor = Predictor()
+        >>> predictor.setup_model(model_path='sam_model.pt')
+        >>> predictor.set_image('image.jpg')
+        >>> masks, scores, boxes = predictor.generate()
+        >>> results = predictor.postprocess((masks, scores, boxes), im, orig_img)
     """
 
     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
         """
         Initialize the Predictor with configuration, overrides, and callbacks.
 
-        The method sets up the Predictor object and applies any configuration overrides or callbacks provided. It
-        initializes task-specific settings for SAM, such as retina_masks being set to True for optimal results.
+        Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or
+        callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True
+        for optimal results.
 
         Args:
-            cfg (dict): Configuration dictionary.
-            overrides (dict, optional): Dictionary of values to override default configuration.
-            _callbacks (dict, optional): Dictionary of callback functions to customize behavior.
+            cfg (Dict): Configuration dictionary containing default settings.
+            overrides (Dict | None): Dictionary of values to override default configuration.
+            _callbacks (Dict | None): Dictionary of callback functions to customize behavior.
+
+        Examples:
+            >>> predictor = Predictor(cfg=DEFAULT_CFG)
+            >>> predictor = Predictor(overrides={'imgsz': 640})
+            >>> predictor = Predictor(_callbacks={'on_predict_start': custom_callback})
         """
         if overrides is None:
             overrides = {}
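The new docstrings center on the override-driven setup. As a companion, here is a minimal usage sketch following the ultralytics SAM predictor docs; the checkpoint name `sam_b.pt` and the asset path are assumptions, not part of this diff:

```python
# Minimal usage sketch (assumes a SAM checkpoint such as "sam_b.pt" is available).
from ultralytics.models.sam import Predictor as SAMPredictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam_b.pt")
predictor = SAMPredictor(overrides=overrides)

predictor.set_image("ultralytics/assets/zidane.jpg")  # encoder runs once, features are cached
results = predictor(bboxes=[439, 437, 524, 709])      # prompt with a single XYXY box
predictor.reset_image()                               # clear cached image/features
```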
@@ -78,14 +107,19 @@ class Predictor(BasePredictor):
         """
         Preprocess the input image for model inference.
 
-        The method prepares the input image by applying transformations and normalization.
-        It supports both torch.Tensor and list of np.ndarray as input formats.
+        This method prepares the input image by applying transformations and normalization. It supports both
+        torch.Tensor and list of np.ndarray as input formats.
 
         Args:
-            im (torch.Tensor | List[np.ndarray]): BCHW tensor format or list of HWC numpy arrays.
+            im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays.
 
         Returns:
-            (torch.Tensor): The preprocessed image tensor.
+            (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> image = torch.rand(1, 3, 640, 640)
+            >>> preprocessed_image = predictor.preprocess(image)
         """
         if self.im is not None:
             return self.im
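For orientation, the normalization mentioned here uses the pixel statistics stored in `self.mean`/`self.std` (see the setup_model hunk further down). A standalone sketch of that step, illustrative only and not the library code; the constants below are the standard SAM defaults:

```python
# Illustrative sketch of SAM-style preprocessing: HWC uint8 -> BCHW float, normalized.
import numpy as np
import torch

mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)  # standard SAM pixel mean
std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)     # standard SAM pixel std

def preprocess_sketch(im_hwc: np.ndarray) -> torch.Tensor:
    im = torch.from_numpy(im_hwc).permute(2, 0, 1).float()[None]  # HWC -> BCHW
    return (im - mean) / std

print(preprocess_sketch(np.zeros((1024, 1024, 3), dtype=np.uint8)).shape)  # [1, 3, 1024, 1024]
```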
@@ -106,14 +140,24 @@ class Predictor(BasePredictor):
         """
         Perform initial transformations on the input image for preprocessing.
 
-        The method applies transformations such as resizing to prepare the image for further preprocessing.
+        This method applies transformations such as resizing to prepare the image for further preprocessing.
+        Currently, batched inference is not supported; hence the list length should be 1.
 
         Args:
-            im (List[np.ndarray]): List containing images in HWC numpy array format.
+            im (List[np.ndarray]): List containing a single image in HWC numpy array format.
 
         Returns:
-            (List[np.ndarray]): List of transformed images.
+            (List[np.ndarray]): List containing the transformed image.
+
+        Raises:
+            AssertionError: If the input list contains more than one image.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> image = np.random.rand(480, 640, 3)  # Single HWC image
+            >>> transformed = predictor.pre_transform([image])
+            >>> print(len(transformed))
+            1
         """
         assert len(im) == 1, "SAM model does not currently support batched inference"
         letterbox = LetterBox(self.args.imgsz, auto=False, center=False)
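The `auto=False, center=False` LetterBox call above pads to the exact target size with the padding on the bottom/right edge rather than centered. A small sketch of that behavior, assuming the current ultralytics import path:

```python
# Sketch of the LetterBox transform used by pre_transform.
import numpy as np
from ultralytics.data.augment import LetterBox

letterbox = LetterBox((1024, 1024), auto=False, center=False)
image = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
out = letterbox(image=image)
print(out.shape)  # (1024, 1024, 3): 480x640 scaled by 1.6 to 768x1024, then padded below
```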
@@ -121,23 +165,32 @@ class Predictor(BasePredictor):
 
     def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs):
         """
-        Perform image segmentation inference based on the given input cues, using the currently loaded image. This
-        method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and
-        mask decoder for real-time and promptable segmentation tasks.
+        Perform image segmentation inference based on the given input cues, using the currently loaded image.
+
+        This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt
+        encoder, and mask decoder for real-time and promptable segmentation tasks.
 
         Args:
             im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
-            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
-            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
-            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
-            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
-            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
+            bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format.
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background.
+            masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256.
+            multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts.
+            *args (Any): Additional positional arguments.
+            **kwargs (Any): Additional keyword arguments.
 
         Returns:
-            (tuple): Contains the following three elements.
-                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
+            (tuple): Contains the following three elements:
+                - np.ndarray: The output masks in shape (C, H, W), where C is the number of generated masks.
                 - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
-                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
+                - np.ndarray: Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_model(model_path='sam_model.pt')
+            >>> predictor.set_image('image.jpg')
+            >>> masks, scores, logits = predictor.inference(im, bboxes=[[0, 0, 100, 100]])
         """
         # Override prompts if any stored in self.prompts
         bboxes = self.prompts.pop("bboxes", bboxes)
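At the model level, these prompt keywords pass straight through to `inference()`. A sketch of prompted prediction via the public `SAM` wrapper, with checkpoint and image names as assumptions:

```python
# Box- and point-prompted inference through the high-level SAM API, which routes
# the keyword prompts down to Predictor.inference().
from ultralytics import SAM

model = SAM("sam_b.pt")
results = model("image.jpg", bboxes=[100, 100, 400, 400])    # XYXY box prompt
results = model("image.jpg", points=[250, 250], labels=[1])  # foreground point prompt
```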
@@ -151,22 +204,30 @@ class Predictor(BasePredictor):
 
     def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False):
         """
-        Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
-        Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
+        Performs image segmentation inference based on input cues using SAM's specialized architecture.
+
+        This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation.
+        It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks.
 
         Args:
-            im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W).
-            bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
-            points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
-            labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
-            masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256.
-            multimask_output (bool, optional): Flag to return multiple masks. Helpful for ambiguous prompts.
+            im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+            bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels.
+            labels (np.ndarray | List | None): Point prompt labels with shape (N,). 1 for foreground, 0 for background.
+            masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256.
+            multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
 
         Returns:
-            (tuple): Contains the following three elements.
-                - np.ndarray: The output masks in shape CxHxW, where C is the number of generated masks.
-                - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
-                - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
+            (tuple): Tuple containing:
+                - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
+                - np.ndarray: Quality scores predicted by the model for each mask, with length C.
+                - np.ndarray: Low-resolution logits with shape (C, H, W) for subsequent inference, where H=W=256.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> im = torch.rand(1, 3, 1024, 1024)
+            >>> bboxes = [[100, 100, 200, 200]]
+            >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes)
         """
         features = self.get_im_features(im) if self.features is None else self.features
 
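Because `prompt_inference()` reuses `self.features` when present, the image encoder runs once per image while the prompt encoder and mask decoder run once per prompt. A sketch of that interaction pattern, with checkpoint and image names assumed:

```python
# Feature caching in practice: set_image() encodes once; each subsequent prompt
# only pays for prompt encoding + mask decoding.
from ultralytics.models.sam import Predictor as SAMPredictor

predictor = SAMPredictor(overrides=dict(task="segment", mode="predict", model="sam_b.pt"))
predictor.set_image("image.jpg")  # image encoder forward pass happens here
for box in ([50, 50, 200, 200], [300, 120, 500, 400]):
    results = predictor(bboxes=box)  # decoder-only work per prompt
predictor.reset_image()
```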
@@ -224,27 +285,32 @@ class Predictor(BasePredictor):
         """
         Perform image segmentation using the Segment Anything Model (SAM).
 
-        This function segments an entire image into constituent parts by leveraging SAM's advanced architecture
+        This method segments an entire image into constituent parts by leveraging SAM's advanced architecture
         and real-time performance capabilities. It can optionally work on image crops for finer segmentation.
 
         Args:
-            im (torch.Tensor): Input tensor representing the preprocessed image with dimensions (N, C, H, W).
-            crop_n_layers (int): Specifies the number of layers for additional mask predictions on image crops.
-                Each layer produces 2**i_layer number of image crops.
-            crop_overlap_ratio (float): Determines the overlap between crops. Scaled down in subsequent layers.
-            crop_downscale_factor (int): Scaling factor for the number of sampled points-per-side in each layer.
-            point_grids (list[np.ndarray], optional): Custom grids for point sampling normalized to [0,1].
-                Used in the nth crop layer.
-            points_stride (int, optional): Number of points to sample along each side of the image.
-                Exclusive with 'point_grids'.
+            im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W).
+            crop_n_layers (int): Number of layers for additional mask predictions on image crops.
+            crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers.
+            crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer.
+            point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1].
+            points_stride (int): Number of points to sample along each side of the image.
             points_batch_size (int): Batch size for the number of points processed simultaneously.
-            conf_thres (float): Confidence threshold [0,1] for filtering based on the model's mask quality prediction.
-            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on mask stability.
+            conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction.
+            stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability.
             stability_score_offset (float): Offset value for calculating stability score.
             crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops.
 
         Returns:
-            (tuple): A tuple containing segmented masks, confidence scores, and bounding boxes.
+            (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing:
+                - pred_masks (torch.Tensor): Segmented masks with shape (N, H, W).
+                - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,).
+                - pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4).
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> im = torch.rand(1, 3, 1024, 1024)  # Example input image
+            >>> masks, scores, boxes = predictor.generate(im)
         """
         import torchvision  # scope for faster 'import ultralytics'
 
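When no prompts are supplied, the pipeline falls back to `generate()`, i.e. "segment everything" over a grid of sampled points and optional crops. A sketch via the public API, checkpoint and image assumed:

```python
# Whole-image segmentation: calling SAM without prompts triggers Predictor.generate(),
# which samples a grid of point prompts (points_stride per side) over optional crops.
from ultralytics import SAM

model = SAM("sam_b.pt")
results = model("image.jpg")  # no bboxes/points -> segment everything
for r in results:
    print(0 if r.masks is None else len(r.masks), "masks kept after NMS")
```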
@@ -326,11 +392,9 @@ class Predictor(BasePredictor):
             model (torch.nn.Module): A pre-trained SAM model. If None, a model will be built based on configuration.
             verbose (bool): If True, prints selected device information.
 
-        Attributes:
-            model (torch.nn.Module): The SAM model allocated to the chosen device for inference.
-            device (torch.device): The device to which the model and tensors are allocated.
-            mean (torch.Tensor): The mean values for image normalization.
-            std (torch.Tensor): The standard deviation values for image normalization.
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_model(model=sam_model, verbose=True)
         """
         device = select_device(self.args.device, verbose=verbose)
         if model is None:
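The removed Attributes block described side effects that still happen: `setup_model()` resolves the device with `select_device` and moves the model and normalization tensors onto it. A sketch of the equivalent manual steps, with the import paths and checkpoint name as assumptions:

```python
# Manual equivalent of the documented setup_model() side effects (sketch only).
from ultralytics.models.sam.build import build_sam
from ultralytics.utils.torch_utils import select_device

device = select_device("")  # "" -> CUDA if available, else CPU
sam_model = build_sam("sam_b.pt").to(device).eval()  # assumes checkpoint is available locally
```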
@@ -349,23 +413,32 @@ class Predictor(BasePredictor):
         self.done_warmup = True
 
     def get_model(self):
-        """Built Segment Anything Model (SAM) model."""
+        """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks."""
         return build_sam(self.args.model)
 
     def postprocess(self, preds, img, orig_imgs):
         """
         Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
 
-        The method scales masks and boxes to the original image size and applies a threshold to the mask predictions.
-        The SAM model uses advanced architecture and promptable segmentation tasks to achieve real-time performance.
+        This method scales masks and boxes to the original image size and applies a threshold to the mask
+        predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks.
 
         Args:
-            preds (tuple): The output from SAM model inference, containing masks, scores, and optional bounding boxes.
-            img (torch.Tensor): The processed input image tensor.
-            orig_imgs (list | torch.Tensor): The original, unprocessed images.
+            preds (Tuple[torch.Tensor]): The output from SAM model inference, containing:
+                - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W).
+                - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1).
+                - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True.
+            img (torch.Tensor): The processed input image tensor with shape (C, H, W).
+            orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images.
 
         Returns:
-            (list): List of Results objects containing detection masks, bounding boxes, and other metadata.
+            (List[Results]): List of Results objects containing detection masks, bounding boxes, and other
+                metadata for each processed image.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> preds = predictor.inference(img)
+            >>> results = predictor.postprocess(preds, img, orig_imgs)
         """
         # (N, 1, H, W), (N, 1)
         pred_masks, pred_scores = preds[:2]
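With the return type pinned to `List[Results]`, downstream code can rely on the standard Results fields. A consumption sketch, again assuming a local checkpoint and image:

```python
# Consuming postprocess() output: each Results carries masks and boxes scaled
# back to the original image size (retina_masks=True for SAM).
from ultralytics.models.sam import Predictor as SAMPredictor

predictor = SAMPredictor(overrides=dict(task="segment", mode="predict", model="sam_b.pt"))
predictor.set_image("image.jpg")
for r in predictor(bboxes=[100, 100, 400, 400]):
    if r.masks is not None:
        print(r.masks.data.shape)  # (N, H, W) masks at original image resolution
    print(r.boxes.xyxy)            # boxes in original-image pixel coordinates
```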
@@ -393,11 +466,23 @@ class Predictor(BasePredictor):
         """
         Sets up the data source for inference.
 
-        This method configures the data source from which images will be fetched for inference. The source could be a
-        directory, a video file, or other types of image data sources.
+        This method configures the data source from which images will be fetched for inference. It supports
+        various input types such as image files, directories, video files, and other compatible data sources.
 
         Args:
-            source (str | Path): The path to the image data source for inference.
+            source (str | Path | None): The path or identifier for the image data source. Can be a file path,
+                directory path, URL, or other supported source types.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.setup_source('path/to/images')
+            >>> predictor.setup_source('video.mp4')
+            >>> predictor.setup_source(None)  # Uses default source if available
+
+        Notes:
+            - If source is None, the method may use a default source if configured.
+            - The method adapts to different source types and prepares them for subsequent inference steps.
+            - Supported source types may include local files, directories, URLs, and video streams.
         """
         if source is not None:
             super().setup_source(source)
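In normal use `setup_source()` is driven indirectly: the source argument of a predict call (or of `set_image`) flows into it. A sketch of the documented source types at the call level; the paths are placeholders:

```python
# Source handling in practice: each of these flows through setup_source() internally.
from ultralytics import SAM

model = SAM("sam_b.pt")
model("path/to/images")                          # directory of images
model("video.mp4")                               # video file
model("https://ultralytics.com/images/bus.jpg")  # URL source
```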
@@ -406,14 +491,25 @@ class Predictor(BasePredictor):
         """
         Preprocesses and sets a single image for inference.
 
-        This function sets up the model if not already initialized, configures the data source to the specified image,
-        and preprocesses the image for feature extraction. Only one image can be set at a time.
+        This method prepares the model for inference on a single image by setting up the model if not already
+        initialized, configuring the data source, and preprocessing the image for feature extraction. It
+        ensures that only one image is set at a time and extracts image features for subsequent use.
 
         Args:
-            image (str | np.ndarray): Image file path as a string, or a np.ndarray image read by cv2.
+            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing
+                an image read by cv2.
 
         Raises:
-            AssertionError: If more than one image is set.
+            AssertionError: If more than one image is attempted to be set.
+
+        Examples:
+            >>> predictor = Predictor()
+            >>> predictor.set_image('path/to/image.jpg')
+            >>> predictor.set_image(cv2.imread('path/to/image.jpg'))
+
+        Notes:
+            - This method should be called before performing inference on a new image.
+            - The extracted features are stored in the `self.features` attribute for later use.
         """
         if self.model is None:
             self.setup_model(model=None)
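Both input forms in one runnable sketch; the checkpoint and image path are assumptions:

```python
# set_image() accepts either a path or an ndarray already loaded with cv2 (BGR, HWC).
import cv2
from ultralytics.models.sam import Predictor as SAMPredictor

predictor = SAMPredictor(overrides=dict(task="segment", mode="predict", model="sam_b.pt"))
predictor.set_image(cv2.imread("path/to/image.jpg"))  # ndarray form
predictor.set_image("path/to/image.jpg")              # path form replaces the previous image
```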
@@ -425,35 +521,44 @@ class Predictor(BasePredictor):
             break
 
     def get_im_features(self, im):
-        """Get image features from the SAM image encoder."""
+        """Extracts image features using the SAM model's image encoder for subsequent mask prediction."""
         return self.model.image_encoder(im)
 
     def set_prompts(self, prompts):
-        """Set prompts in advance."""
+        """Sets prompts for subsequent inference operations."""
         self.prompts = prompts
 
     def reset_image(self):
-        """Resets the image and its features to None."""
+        """Resets the current image and its features, clearing them for subsequent inference."""
         self.im = None
         self.features = None
 
     @staticmethod
     def remove_small_regions(masks, min_area=0, nms_thresh=0.7):
         """
-        Perform post-processing on segmentation masks generated by the Segment Anything Model (SAM). Specifically, this
-        function removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
+        Remove small disconnected regions and holes from segmentation masks.
+
+        This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM).
+        It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum
         Suppression (NMS) to eliminate any newly created duplicate boxes.
 
         Args:
-            masks (torch.Tensor): A tensor containing the masks to be processed. Shape should be (N, H, W), where N is
-                the number of masks, H is height, and W is width.
-            min_area (int): The minimum area below which disconnected regions and holes will be removed. Defaults to 0.
-            nms_thresh (float): The IoU threshold for the NMS algorithm. Defaults to 0.7.
+            masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of
+                masks, H is height, and W is width.
+            min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than
+                this will be removed.
+            nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes.
 
         Returns:
-            (tuple([torch.Tensor, List[int]])):
-                - new_masks (torch.Tensor): The processed masks with small regions removed. Shape is (N, H, W).
-                - keep (List[int]): The indices of the remaining masks post-NMS, which can be used to filter the boxes.
+            (tuple):
+                - new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W).
+                - keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes.
+
+        Examples:
+            >>> masks = torch.rand(5, 640, 640) > 0.5  # 5 random binary masks
+            >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7)
+            >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}")
+            >>> print(f"Indices of kept masks: {keep}")
        """
         import torchvision  # scope for faster 'import ultralytics'
 
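Since this is a static method, it can be used as a standalone cleanup pass on any (N, H, W) mask tensor. A small sketch with toy inputs; kept deliberately small since the region analysis is per-mask and CPU-bound:

```python
# remove_small_regions() as a standalone cleanup pass: fill/drop areas under
# min_area pixels, then NMS the recomputed boxes at the given IoU threshold.
import torch
from ultralytics.models.sam import Predictor as SAMPredictor

masks = torch.rand(2, 256, 256) > 0.5  # toy binary masks (noisy on purpose)
new_masks, keep = SAMPredictor.remove_small_regions(masks, min_area=64, nms_thresh=0.7)
print(new_masks.shape, keep)
```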
@@ -480,3 +585,188 @@ class Predictor(BasePredictor):
         keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh)
 
         return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep
+
+
+class SAM2Predictor(Predictor):
+    """
+    SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture.
+
+    This class extends the base Predictor class to implement SAM2-specific functionality for image
+    segmentation tasks. It provides methods for model initialization, feature extraction, and
+    prompt-based inference.
+
+    Attributes:
+        _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels.
+        model (torch.nn.Module): The loaded SAM2 model.
+        device (torch.device): The device (CPU or GPU) on which the model is loaded.
+        features (Dict[str, torch.Tensor]): Cached image features for efficient inference.
+        segment_all (bool): Flag to indicate if all segments should be predicted.
+        prompts (Dict): Dictionary to store various types of prompts for inference.
+
+    Methods:
+        get_model: Retrieves and initializes the SAM2 model.
+        prompt_inference: Performs image segmentation inference based on various prompts.
+        set_image: Preprocesses and sets a single image for inference.
+        get_im_features: Extracts and processes image features using SAM2's image encoder.
+
+    Examples:
+        >>> predictor = SAM2Predictor(cfg)
+        >>> predictor.set_image("path/to/image.jpg")
+        >>> bboxes = [[100, 100, 200, 200]]
+        >>> masks, scores, _ = predictor.prompt_inference(predictor.im, bboxes=bboxes)
+        >>> print(f"Predicted {len(masks)} masks with average score {scores.mean():.2f}")
+    """
+
+    _bb_feat_sizes = [
+        (256, 256),
+        (128, 128),
+        (64, 64),
+    ]
+
+    def get_model(self):
+        """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks."""
+        return build_sam(self.args.model)
+
+    def prompt_inference(
+        self,
+        im,
+        bboxes=None,
+        points=None,
+        labels=None,
+        masks=None,
+        multimask_output=False,
+        img_idx=-1,
+    ):
+        """
+        Performs image segmentation inference based on various prompts using SAM2 architecture.
+
+        This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images
+        based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and
+        multi-object prediction scenarios.
+
+        Args:
+            im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W).
+            bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4).
+            points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels.
+            labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background.
+            masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W).
+            multimask_output (bool): Flag to return multiple masks for ambiguous prompts.
+            img_idx (int): Index of the image in the batch to process.
+
+        Returns:
+            (tuple): Tuple containing:
+                - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks.
+                - np.ndarray: Quality scores for each mask, with length C.
+                - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference.
+
+        Examples:
+            >>> predictor = SAM2Predictor(cfg)
+            >>> image = torch.rand(1, 3, 640, 640)
+            >>> bboxes = [[100, 100, 200, 200]]
+            >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes)
+            >>> print(f"Generated {masks.shape[0]} masks with average score {scores.mean():.2f}")
+
+        Notes:
+            - The method supports batched inference for multiple objects when points or bboxes are provided.
+            - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions.
+            - When both bboxes and points are provided, they are merged into a single 'points' input for the model.
+
+        References:
+            - SAM2 Paper: [Add link to SAM2 paper when available]
+        """
+        features = self.get_im_features(im) if self.features is None else self.features
+
+        src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
+        r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
+        # Transform input prompts
+        if points is not None:
+            points = torch.as_tensor(points, dtype=torch.float32, device=self.device)
+            points = points[None] if points.ndim == 1 else points
+            # Assuming labels are all positive if users don't pass labels.
+            if labels is None:
+                labels = torch.ones(points.shape[0])
+            labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
+            points *= r
+            # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1)
+            points, labels = points[:, None], labels[:, None]
+        if bboxes is not None:
+            bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device)
+            bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
+            bboxes = bboxes.view(-1, 2, 2) * r
+            bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1)
+            # NOTE: merge "boxes" and "points" into a single "points" input
+            # (where boxes are added at the beginning) to model.sam_prompt_encoder
+            if points is not None:
+                points = torch.cat([bboxes, points], dim=1)
+                labels = torch.cat([bbox_labels, labels], dim=1)
+            else:
+                points, labels = bboxes, bbox_labels
+        if masks is not None:
+            masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1)
+
+        points = (points, labels) if points is not None else None
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=points,
+            boxes=None,
+            masks=masks,
+        )
+        # Predict masks
+        batched_mode = points is not None and points[0].shape[0] > 1  # multi object prediction
+        high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]]
+        pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=features["image_embed"][img_idx].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=batched_mode,
+            high_res_features=high_res_features,
+        )
+        # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, )
+        # `d` could be 1 or 3 depends on `multimask_output`.
+        return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1)
+
+    def set_image(self, image):
+        """
+        Preprocesses and sets a single image for inference using the SAM2 model.
+
+        This method initializes the model if not already done, configures the data source to the specified image,
+        and preprocesses the image for feature extraction. It supports setting only one image at a time.
+
+        Args:
+            image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image.
+
+        Raises:
+            AssertionError: If more than one image is attempted to be set.
+
+        Examples:
+            >>> predictor = SAM2Predictor()
+            >>> predictor.set_image("path/to/image.jpg")
+            >>> predictor.set_image(np.array([...]))  # Using a numpy array
+
+        Notes:
+            - This method must be called before performing any inference on a new image.
+            - The method caches the extracted features for efficient subsequent inferences on the same image.
+            - Only one image can be set at a time. To process multiple images, call this method for each new image.
+        """
+        if self.model is None:
+            self.setup_model(model=None)
+        self.setup_source(image)
+        assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
+        for batch in self.dataset:
+            im = self.preprocess(batch[1])
+            self.features = self.get_im_features(im)
+            break
+
+    def get_im_features(self, im):
+        """Extracts image features from the SAM image encoder for subsequent processing."""
+        backbone_out = self.model.forward_image(im)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+        feats = [
+            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
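Functionally, the new subclass swaps in the SAM2 image encoder path (`forward_image` plus a backbone feature pyramid) and encodes box prompts as corner-point pairs with labels 2 and 3 before the shared prompt encoder. Usage mirrors SAM1; a sketch assuming a SAM2 checkpoint name from this release line (e.g. `sam2_b.pt`):

```python
# SAM2 usage sketch: the same prompt-driven API, dispatched to SAM2Predictor.
from ultralytics import SAM

model = SAM("sam2_b.pt")  # assumed SAM2 checkpoint name from this release line
results = model("image.jpg", bboxes=[[100, 100, 200, 200]])    # box prompt
results = model("image.jpg", points=[[250, 250]], labels=[1])  # point prompt
```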