ultralytics 8.2.70 Segment Anything Model 2 (SAM 2) (#14813)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Laughing 2024-07-30 22:06:49 +08:00 committed by GitHub
parent 80f699ae21
commit 8648572809
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 3276 additions and 77 deletions

View file

@@ -14,7 +14,7 @@ from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import ImageEncoderViT, PromptEncoder
from .modules.sam import Sam
from .modules.sam import SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
@@ -105,7 +105,7 @@ def _build_sam(
out_chans=prompt_embed_dim,
)
)
sam = Sam(
sam = SAMModel(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,

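A minimal sketch of what the rename above means for downstream code, assuming build_sam keeps its public signature and the import paths follow the package layout shown in this diff (checkpoint name is illustrative):

# Illustrative sketch, not part of the diff: the builder now returns SAMModel.
from ultralytics.models.sam.build import build_sam
from ultralytics.models.sam.modules.sam import SAMModel

model = build_sam("sam_b.pt")        # example checkpoint name
assert isinstance(model, SAMModel)   # previously this was the Sam class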
View file

@@ -44,6 +44,7 @@ class SAM(Model):
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
@@ -54,7 +55,12 @@ class SAM(Model):
weights (str): Path to the weights file.
task (str, optional): Task name. Defaults to None.
"""
self.model = build_sam(weights)
if self.is_sam2:
from ..sam2.build import build_sam2
self.model = build_sam2(weights)
else:
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""
@@ -112,4 +118,6 @@ class SAM(Model):
Returns:
(dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'.
"""
return {"segment": {"predictor": Predictor}}
from ..sam2.predict import SAM2Predictor
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}

View file

@@ -4,9 +4,8 @@ from typing import List, Tuple, Type
import torch
from torch import nn
from torch.nn import functional as F
from ultralytics.nn.modules import LayerNorm2d
from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
@@ -28,7 +27,6 @@ class MaskDecoder(nn.Module):
def __init__(
self,
*,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
@@ -149,42 +147,3 @@ class MaskDecoder(nn.Module):
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
class MLP(nn.Module):
"""
MLP (Multi-Layer Perceptron) model lightly adapted from
https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
"""
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
sigmoid_output: bool = False,
) -> None:
"""
Initializes the MLP (Multi-Layer Perceptron) model.
Args:
input_dim (int): The dimensionality of the input features.
hidden_dim (int): The dimensionality of the hidden layers.
output_dim (int): The dimensionality of the output layer.
num_layers (int): The number of hidden layers.
sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False.
"""
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
self.sigmoid_output = sigmoid_output
def forward(self, x):
"""Executes feedforward within the neural network module and applies activation."""
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
if self.sigmoid_output:
x = torch.sigmoid(x)
return x
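The local MLP class above is dropped in favor of the shared MLP imported from ultralytics.nn.modules (see the updated import at the top of this file). A sketch of an equivalent head, assuming the shared class keeps the same positional arguments (input_dim, hidden_dim, output_dim, num_layers) as the removed implementation; the dimensions below are illustrative:

import torch
from ultralytics.nn.modules import MLP

# Example: a small prediction head mapping 256-dim tokens to 4 scores.
head = MLP(256, 256, 4, 3)
scores = head(torch.randn(2, 256))  # -> shape (2, 4)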

View file

@@ -211,6 +211,8 @@ class PromptEncoder(nn.Module):
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding[labels == 2] += self.point_embeddings[2].weight
point_embedding[labels == 3] += self.point_embeddings[3].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
@@ -226,8 +228,8 @@
"""Embeds mask inputs."""
return self.mask_downscaling(masks)
@staticmethod
def _get_batch_size(
self,
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
boxes: Optional[torch.Tensor],
masks: Optional[torch.Tensor],

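The two new embeddings extend _embed_points to labels 2 and 3, which SAM 2 uses to express a box prompt as a pair of corner points. A hedged sketch of how such a prompt could be assembled, following the usual (B, N, 2) points / (B, N) labels convention:

import torch

# One box prompt expressed as two labelled corner points:
#   label 2 -> top-left corner, label 3 -> bottom-right corner (SAM 2 convention).
points = torch.tensor([[[100.0, 100.0], [200.0, 200.0]]])  # (1, 2, 2) pixel coords
labels = torch.tensor([[2, 3]])                            # (1, 2)
# _embed_points then adds point_embeddings[2] / point_embeddings[3]
# to the respective corner embeddings, as shown in the hunk above.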
View file

@@ -15,15 +15,14 @@ from .decoders import MaskDecoder
from .encoders import ImageEncoderViT, PromptEncoder
class Sam(nn.Module):
class SAMModel(nn.Module):
"""
Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image
embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask
decoder to predict object masks.
SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate
image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by
the mask decoder to predict object masks.
Attributes:
mask_threshold (float): Threshold value for mask prediction.
image_format (str): Format of the input image, default is 'RGB'.
image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings.
@@ -32,7 +31,6 @@ class Sam(nn.Module):
"""
mask_threshold: float = 0.0
image_format: str = "RGB"
def __init__(
self,
@@ -43,7 +41,7 @@ class Sam(nn.Module):
pixel_std: List[float] = (58.395, 57.12, 57.375),
) -> None:
"""
Initialize the Sam class to predict object masks from an image and input prompts.
Initialize the SAMModel class to predict object masks from an image and input prompts.
Note:
All forward() operations moved to SAMPredictor.
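As a side note on the constructor defaults visible above: pixel_std (and the matching pixel_mean) are the standard ImageNet statistics in the 0-255 range, so preprocessing amounts to per-channel standardization. A small illustrative sketch; the mean values are assumed here since only pixel_std appears in this hunk:

import torch

pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)  # assumed default
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)     # from the signature above
x = torch.rand(3, 1024, 1024) * 255                                  # an RGB image in 0-255
x_norm = (x - pixel_mean) / pixel_std                                # per-channel standardization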

View file

@@ -86,7 +86,6 @@ class TwoWayTransformer(nn.Module):
(torch.Tensor): the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
bs, c, h, w = image_embedding.shape
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
@@ -212,6 +211,7 @@ class Attention(nn.Module):
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
kv_in_dim: int = None,
) -> None:
"""
Initializes the Attention model with the given dimensions and settings.
@@ -226,13 +226,14 @@
"""
super().__init__()
self.embedding_dim = embedding_dim
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
@staticmethod

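The new kv_in_dim argument lets keys and values come from a feature space with a different width than the queries (useful for SAM 2 memory attention) while the output stays in embedding_dim. A hedged sketch, assuming Attention.forward(q, k, v) keeps its three-tensor signature and the import path follows the layout shown in this diff:

import torch
from ultralytics.models.sam.modules.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, kv_in_dim=64)
q = torch.randn(1, 100, 256)   # queries in the 256-dim token space
kv = torch.randn(1, 400, 64)   # keys/values from a 64-dim source, e.g. memory features
out = attn(q, kv, kv)          # -> (1, 100, 256)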
View file

@@ -168,7 +168,7 @@ class Predictor(BasePredictor):
- np.ndarray: An array of length C containing quality scores predicted by the model for each mask.
- np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256.
"""
features = self.model.image_encoder(im) if self.features is None else self.features
features = self.get_im_features(im) if self.features is None else self.features
src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:]
r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])
@@ -334,7 +334,7 @@ class Predictor(BasePredictor):
"""
device = select_device(self.args.device, verbose=verbose)
if model is None:
model = build_sam(self.args.model)
model = self.get_model()
model.eval()
self.model = model.to(device)
self.device = device
@@ -348,6 +348,10 @@
self.model.fp16 = False
self.done_warmup = True
def get_model(self):
"""Built Segment Anything Model (SAM) model."""
return build_sam(self.args.model)
def postprocess(self, preds, img, orig_imgs):
"""
Post-processes SAM's inference outputs to generate object detection masks and bounding boxes.
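setup_model now defers model construction to the get_model hook added above, so a subclass only needs to override that one method to swap the underlying network. A rough sketch of the intended override pattern; this is an illustrative subclass, not the SAM2Predictor shipped in this PR, and build_sam2 is imported from the path used in SAM._load earlier in this diff:

from ultralytics.models.sam.predict import Predictor
from ultralytics.models.sam2.build import build_sam2

class MySAM2Predictor(Predictor):
    """Illustrative predictor that swaps in a SAM 2 network via the get_model hook."""

    def get_model(self):
        """Build a SAM 2 network instead of the original SAM."""
        return build_sam2(self.args.model)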
@@ -412,16 +416,18 @@
AssertionError: If more than one image is set.
"""
if self.model is None:
model = build_sam(self.args.model)
self.setup_model(model)
self.setup_model(model=None)
self.setup_source(image)
assert len(self.dataset) == 1, "`set_image` only supports setting one image!"
for batch in self.dataset:
im = self.preprocess(batch[1])
self.features = self.model.image_encoder(im)
self.im = im
self.features = self.get_im_features(im)
break
def get_im_features(self, im):
"""Get image features from the SAM image encoder."""
return self.model.image_encoder(im)
def set_prompts(self, prompts):
"""Set prompts in advance."""
self.prompts = prompts
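Factoring the encoder call into get_im_features means the main inference path and set_image share one hook, and feature caching works the same way from the caller's side. A hedged usage sketch; the override settings, checkpoint name, image path, and prompt values below are placeholders:

from ultralytics.models.sam.predict import Predictor

overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model="sam_b.pt")
predictor = Predictor(overrides=overrides)
predictor.set_image("path/to/image.jpg")             # builds the model if needed, caches features
results = predictor(points=[[500, 375]], labels=[1])  # reuses the cached image features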