ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent 80f699ae21
commit 8648572809

36 changed files with 3276 additions and 77 deletions
ultralytics/models/sam/build.py

```diff
@@ -14,7 +14,7 @@ from ultralytics.utils.downloads import attempt_download_asset

 from .modules.decoders import MaskDecoder
 from .modules.encoders import ImageEncoderViT, PromptEncoder
-from .modules.sam import Sam
+from .modules.sam import SAMModel
 from .modules.tiny_encoder import TinyViT
 from .modules.transformer import TwoWayTransformer
@@ -105,7 +105,7 @@ def _build_sam(
             out_chans=prompt_embed_dim,
         )
     )
-    sam = Sam(
+    sam = SAMModel(
         image_encoder=image_encoder,
         prompt_encoder=PromptEncoder(
             embed_dim=prompt_embed_dim,
```
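The only functional change in build.py is the rename of the returned class from `Sam` to `SAMModel`. A minimal sketch of what this means for downstream code (assuming the `sam_b.pt` checkpoint exists locally or can be downloaded):

```python
from ultralytics.models.sam.build import build_sam
from ultralytics.models.sam.modules.sam import SAMModel

# build_sam accepts the same checkpoint names as before; only the class name changed.
model = build_sam("sam_b.pt")
assert isinstance(model, SAMModel)  # formerly `Sam`
```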
ultralytics/models/sam/model.py

```diff
@@ -44,6 +44,7 @@ class SAM(Model):
         """
         if model and Path(model).suffix not in {".pt", ".pth"}:
             raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
+        self.is_sam2 = "sam2" in Path(model).stem
         super().__init__(model=model, task="segment")

     def _load(self, weights: str, task=None):
@@ -54,7 +55,12 @@ class SAM(Model):
             weights (str): Path to the weights file.
             task (str, optional): Task name. Defaults to None.
         """
-        self.model = build_sam(weights)
+        if self.is_sam2:
+            from ..sam2.build import build_sam2
+
+            self.model = build_sam2(weights)
+        else:
+            self.model = build_sam(weights)

     def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
         """
@@ -112,4 +118,6 @@ class SAM(Model):
         Returns:
             (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
         """
-        return {"segment": {"predictor": Predictor}}
+        from ..sam2.predict import SAM2Predictor
+
+        return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
```
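With `is_sam2` set from the weight filename stem, the same `SAM` entry point now routes to either backend. A short usage sketch (image path is a placeholder; checkpoints are assumed downloadable):

```python
from ultralytics import SAM

# A "sam2" stem routes to build_sam2() and SAM2Predictor; any other stem
# keeps the original SAM path with build_sam() and Predictor.
sam1 = SAM("sam_b.pt")   # original SAM backend
sam2 = SAM("sam2_b.pt")  # SAM 2 backend

# Prompted inference uses the same bboxes/points/labels kwargs for both.
results = sam2("path/to/image.jpg", points=[900, 370], labels=[1])
```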
ultralytics/models/sam/modules/decoders.py

```diff
@@ -4,9 +4,8 @@ from typing import List, Tuple, Type

 import torch
 from torch import nn
-from torch.nn import functional as F

-from ultralytics.nn.modules import LayerNorm2d
+from ultralytics.nn.modules import MLP, LayerNorm2d


 class MaskDecoder(nn.Module):
@@ -28,7 +27,6 @@ class MaskDecoder(nn.Module):

     def __init__(
         self,
-        *,
         transformer_dim: int,
         transformer: nn.Module,
         num_multimask_outputs: int = 3,
@@ -149,42 +147,3 @@ class MaskDecoder(nn.Module):
         iou_pred = self.iou_prediction_head(iou_token_out)

         return masks, iou_pred
-
-
-class MLP(nn.Module):
-    """
-    MLP (Multi-Layer Perceptron) model lightly adapted from
-    https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dim: int,
-        output_dim: int,
-        num_layers: int,
-        sigmoid_output: bool = False,
-    ) -> None:
-        """
-        Initializes the MLP (Multi-Layer Perceptron) model.
-
-        Args:
-            input_dim (int): The dimensionality of the input features.
-            hidden_dim (int): The dimensionality of the hidden layers.
-            output_dim (int): The dimensionality of the output layer.
-            num_layers (int): The number of hidden layers.
-            sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
-        """
-        super().__init__()
-        self.num_layers = num_layers
-        h = [hidden_dim] * (num_layers - 1)
-        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-        self.sigmoid_output = sigmoid_output
-
-    def forward(self, x):
-        """Executes feedforward within the neural network module and applies activation."""
-        for i, layer in enumerate(self.layers):
-            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
-        if self.sigmoid_output:
-            x = torch.sigmoid(x)
-        return x
```
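The private MLP head is replaced by the shared `MLP` module in `ultralytics.nn.modules`, which the decoder now imports. A minimal sketch of equivalent usage (positional arguments mirror the removed class; the shared module may expose extra keyword options not shown here):

```python
import torch

from ultralytics.nn.modules import MLP  # shared module replacing the local copy

# Same positional layout as the removed class: input_dim, hidden_dim, output_dim, num_layers.
# 256 -> 256 -> 32 over 3 layers matches the decoder's hypernetwork heads at transformer_dim=256.
head = MLP(256, 256, 32, 3)
x = torch.randn(1, 4, 256)  # (batch, tokens, input_dim)
print(head(x).shape)        # torch.Size([1, 4, 32])
```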
ultralytics/models/sam/modules/encoders.py

```diff
@@ -211,6 +211,8 @@ class PromptEncoder(nn.Module):
         point_embedding[labels == -1] += self.not_a_point_embed.weight
         point_embedding[labels == 0] += self.point_embeddings[0].weight
         point_embedding[labels == 1] += self.point_embeddings[1].weight
+        point_embedding[labels == 2] += self.point_embeddings[2].weight
+        point_embedding[labels == 3] += self.point_embeddings[3].weight
         return point_embedding

     def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
@@ -226,8 +228,8 @@ class PromptEncoder(nn.Module):
         """Embeds mask inputs."""
         return self.mask_downscaling(masks)

+    @staticmethod
     def _get_batch_size(
-        self,
         points: Optional[Tuple[torch.Tensor, torch.Tensor]],
         boxes: Optional[torch.Tensor],
         masks: Optional[torch.Tensor],
```
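SAM 2 feeds box corners through the point pathway using labels 2 and 3, which is why `_embed_points` now needs embeddings for those label values. A hedged sketch of how a box prompt can be expressed as labeled points under that convention (tensor shapes are illustrative):

```python
import torch

# Hypothetical illustration: a box (x1, y1, x2, y2) expressed as two labeled
# points, following SAM 2's convention of label 2 = top-left, 3 = bottom-right.
box = torch.tensor([100.0, 150.0, 400.0, 500.0])
coords = box.reshape(2, 2).unsqueeze(0)  # (1, 2, 2): two corner points
labels = torch.tensor([[2, 3]])          # (1, 2): corner labels
# These (coords, labels) flow through the same point-embedding pathway,
# picking up point_embeddings[2] and point_embeddings[3] respectively.
```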
ultralytics/models/sam/modules/sam.py

```diff
@@ -15,15 +15,14 @@ from .decoders import MaskDecoder
 from .encoders import ImageEncoderViT, PromptEncoder


-class Sam(nn.Module):
+class SAMModel(nn.Module):
     """
-    Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
-    embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
-    decoder to predict object masks.
+    SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
+    image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
+    the mask decoder to predict object masks.

     Attributes:
         mask_threshold (float): Threshold value for mask prediction.
         image_format (str): Format of the input image, default is 'RGB'.
         image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
         prompt_encoder (PromptEncoder): Encodes various types of input prompts.
         mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
@@ -32,7 +31,6 @@ class Sam(nn.Module):
     """
-
     mask_threshold: float = 0.0
     image_format: str = "RGB"

     def __init__(
         self,
@@ -43,7 +41,7 @@ class Sam(nn.Module):
         pixel_std: List[float] = (58.395, 57.12, 57.375),
     ) -> None:
         """
-        Initialize the Sam class to predict object masks from an image and input prompts.
+        Initialize the SAMModel class to predict object masks from an image and input prompts.

         Note:
            All forward() operations moved to SAMPredictor.
```
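The `pixel_mean`/`pixel_std` defaults are the standard ImageNet statistics scaled to 0-255 pixel values. A minimal sketch of the normalization they imply (illustrative only; the actual preprocessing lives in the predictor):

```python
import torch

# ImageNet mean/std scaled by 255, matching SAMModel's constructor defaults.
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

img = torch.randint(0, 256, (3, 1024, 1024)).float()  # RGB image in 0..255
normalized = (img - pixel_mean) / pixel_std            # per-channel normalization
```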
ultralytics/models/sam/modules/transformer.py

```diff
@@ -86,7 +86,6 @@ class TwoWayTransformer(nn.Module):
             (torch.Tensor): the processed image_embedding
         """
         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
-        bs, c, h, w = image_embedding.shape
         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
         image_pe = image_pe.flatten(2).permute(0, 2, 1)

@@ -212,6 +211,7 @@ class Attention(nn.Module):
         embedding_dim: int,
         num_heads: int,
         downsample_rate: int = 1,
+        kv_in_dim: int = None,
     ) -> None:
         """
         Initializes the Attention model with the given dimensions and settings.
@@ -226,13 +226,14 @@ class Attention(nn.Module):
         """
         super().__init__()
         self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
         self.internal_dim = embedding_dim // downsample_rate
         self.num_heads = num_heads
         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
-        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

     @staticmethod
```
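The new `kv_in_dim` argument lets keys and values come from a tensor with a different channel width than the queries, a pattern SAM 2's memory attention relies on. A hedged sketch of the constructor in isolation (shapes are illustrative):

```python
import torch

from ultralytics.models.sam.modules.transformer import Attention

# Queries at 256 channels attending over 64-channel key/value features.
attn = Attention(embedding_dim=256, num_heads=8, kv_in_dim=64)
q = torch.randn(1, 100, 256)  # query tokens at embedding_dim
kv = torch.randn(1, 500, 64)  # key/value tokens in their native width
out = attn(q, kv, kv)         # k_proj/v_proj map 64 -> internal_dim
print(out.shape)              # torch.Size([1, 100, 256])
```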
ultralytics/models/sam/predict.py

```diff
@@ -168,7 +168,7 @@ class Predictor(BasePredictor):
             - np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
             - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
         """
-        features = self.model.image_encoder(im) if self.features is None else self.features
+        features = self.get_im_features(im) if self.features is None else self.features

         src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
         r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@@ -334,7 +334,7 @@ class Predictor(BasePredictor):
         """
         device = select_device(self.args.device, verbose=verbose)
         if model is None:
-            model = build_sam(self.args.model)
+            model = self.get_model()
         model.eval()
         self.model = model.to(device)
         self.device = device
@@ -348,6 +348,10 @@ class Predictor(BasePredictor):
         self.model.fp16 = False
         self.done_warmup = True

+    def get_model(self):
+        """Built Segment Anything Model (SAM) model."""
+        return build_sam(self.args.model)
+
     def postprocess(self, preds, img, orig_imgs):
         """
         Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
@@ -412,16 +416,18 @@ class Predictor(BasePredictor):
             AssertionError: If more than one image is set.
         """
         if self.model is None:
-            model = build_sam(self.args.model)
-            self.setup_model(model)
+            self.setup_model(model=None)
         self.setup_source(image)
         assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
         for batch in self.dataset:
             im = self.preprocess(batch[1])
-            self.features = self.model.image_encoder(im)
+            self.im = im
+            self.features = self.get_im_features(im)
             break

+    def get_im_features(self, im):
+        """Get image features from the SAM image encoder."""
+        return self.model.image_encoder(im)
+
     def set_prompts(self, prompts):
         """Set prompts in advance."""
         self.prompts = prompts
```
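These `get_model`/`get_im_features` hooks are what let `SAM2Predictor` swap in its own model builder and feature extraction while reusing the rest of `Predictor`. A hedged sketch of the override pattern (the method bodies below are illustrative placeholders, not the actual SAM2Predictor implementation):

```python
from ultralytics.models.sam.predict import Predictor


class CustomSAMPredictor(Predictor):
    """Illustrative subclass: only the two new hooks need overriding."""

    def get_model(self):
        """Build and return a SAM-style model (placeholder: reuses build_sam)."""
        from ultralytics.models.sam.build import build_sam

        return build_sam(self.args.model)

    def get_im_features(self, im):
        """Extract image embeddings; SAM2Predictor supplies its own version here."""
        return self.model.image_encoder(im)
```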