Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
332 lines
13 KiB
Python
332 lines
13 KiB
Python
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
|
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from ultralytics.models.sam.modules.encoders import PatchEmbed
|
|
|
|
from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine
|
|
|
|
|
|
class MemoryEncoder(nn.Module):
|
|
"""Encodes pixel features and masks into a memory representation for efficient image segmentation."""
|
|
|
|
def __init__(
|
|
self,
|
|
out_dim,
|
|
in_dim=256, # in_dim of pix_feats
|
|
):
|
|
"""Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models."""
|
|
super().__init__()
|
|
|
|
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
|
|
|
|
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
|
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
|
|
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
|
|
self.out_proj = nn.Identity()
|
|
if out_dim != in_dim:
|
|
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
|
|
|
def forward(
|
|
self,
|
|
pix_feat: torch.Tensor,
|
|
masks: torch.Tensor,
|
|
skip_mask_sigmoid: bool = False,
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""Processes pixel features and masks, fusing them to generate encoded memory representations."""
|
|
if not skip_mask_sigmoid:
|
|
masks = F.sigmoid(masks)
|
|
masks = self.mask_downsampler(masks)
|
|
|
|
# Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA
|
|
pix_feat = pix_feat.to(masks.device)
|
|
|
|
x = self.pix_feat_proj(pix_feat)
|
|
x = x + masks
|
|
x = self.fuser(x)
|
|
x = self.out_proj(x)
|
|
|
|
pos = self.position_encoding(x).to(x.dtype)
|
|
|
|
return {"vision_features": x, "vision_pos_enc": [pos]}
|
|
|
|
|
|
class ImageEncoder(nn.Module):
|
|
"""Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings."""
|
|
|
|
def __init__(
|
|
self,
|
|
trunk: nn.Module,
|
|
neck: nn.Module,
|
|
scalp: int = 0,
|
|
):
|
|
"""Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction."""
|
|
super().__init__()
|
|
self.trunk = trunk
|
|
self.neck = neck
|
|
self.scalp = scalp
|
|
assert (
|
|
self.trunk.channel_list == self.neck.backbone_channel_list
|
|
), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
|
|
|
def forward(self, sample: torch.Tensor):
|
|
"""Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs."""
|
|
features, pos = self.neck(self.trunk(sample))
|
|
if self.scalp > 0:
|
|
# Discard the lowest resolution features
|
|
features, pos = features[: -self.scalp], pos[: -self.scalp]
|
|
|
|
src = features[-1]
|
|
output = {
|
|
"vision_features": src,
|
|
"vision_pos_enc": pos,
|
|
"backbone_fpn": features,
|
|
}
|
|
return output
|
|
|
|
|
|
class FpnNeck(nn.Module):
|
|
"""Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models."""
|
|
|
|
def __init__(
|
|
self,
|
|
d_model: int,
|
|
backbone_channel_list: List[int],
|
|
kernel_size: int = 1,
|
|
stride: int = 1,
|
|
padding: int = 0,
|
|
fpn_interp_model: str = "bilinear",
|
|
fuse_type: str = "sum",
|
|
fpn_top_down_levels: Optional[List[int]] = None,
|
|
):
|
|
"""
|
|
Initializes a modified Feature Pyramid Network (FPN) neck.
|
|
|
|
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
|
similar to ViT positional embedding interpolation.
|
|
|
|
Args:
|
|
d_model (int): Dimension of the model.
|
|
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
|
kernel_size (int): Kernel size for the convolutional layers.
|
|
stride (int): Stride for the convolutional layers.
|
|
padding (int): Padding for the convolutional layers.
|
|
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
|
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
|
fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs.
|
|
|
|
Attributes:
|
|
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding.
|
|
convs (nn.ModuleList): List of convolutional layers for each backbone level.
|
|
backbone_channel_list (List[int]): List of channel dimensions from the backbone.
|
|
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
|
fuse_type (str): Type of feature fusion.
|
|
fpn_top_down_levels (List[int]): Levels with top-down feature propagation.
|
|
|
|
Examples:
|
|
>>> backbone_channels = [64, 128, 256, 512]
|
|
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
|
>>> print(fpn_neck)
|
|
"""
|
|
super().__init__()
|
|
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
|
|
self.convs = nn.ModuleList()
|
|
self.backbone_channel_list = backbone_channel_list
|
|
for dim in backbone_channel_list:
|
|
current = nn.Sequential()
|
|
current.add_module(
|
|
"conv",
|
|
nn.Conv2d(
|
|
in_channels=dim,
|
|
out_channels=d_model,
|
|
kernel_size=kernel_size,
|
|
stride=stride,
|
|
padding=padding,
|
|
),
|
|
)
|
|
|
|
self.convs.append(current)
|
|
self.fpn_interp_model = fpn_interp_model
|
|
assert fuse_type in ["sum", "avg"]
|
|
self.fuse_type = fuse_type
|
|
|
|
# levels to have top-down features in its outputs
|
|
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
|
# have top-down propagation, while outputs of level 0 and level 1 have only
|
|
# lateral features from the same backbone level.
|
|
if fpn_top_down_levels is None:
|
|
# default is to have top-down features on all levels
|
|
fpn_top_down_levels = range(len(self.convs))
|
|
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
|
|
|
def forward(self, xs: List[torch.Tensor]):
|
|
"""
|
|
Performs forward pass through the Feature Pyramid Network (FPN) neck.
|
|
|
|
Args:
|
|
xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor.
|
|
|
|
Returns:
|
|
(Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists:
|
|
- out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor.
|
|
- pos: List of positional encodings corresponding to each output feature map.
|
|
|
|
Examples:
|
|
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
|
|
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
|
|
>>> outputs, positions = fpn_neck(inputs)
|
|
"""
|
|
out = [None] * len(self.convs)
|
|
pos = [None] * len(self.convs)
|
|
assert len(xs) == len(self.convs)
|
|
# fpn forward pass
|
|
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
|
prev_features = None
|
|
# forward in top-down order (from low to high resolution)
|
|
n = len(self.convs) - 1
|
|
for i in range(n, -1, -1):
|
|
x = xs[i]
|
|
lateral_features = self.convs[n - i](x)
|
|
if i in self.fpn_top_down_levels and prev_features is not None:
|
|
top_down_features = F.interpolate(
|
|
prev_features.to(dtype=torch.float32),
|
|
scale_factor=2.0,
|
|
mode=self.fpn_interp_model,
|
|
align_corners=(None if self.fpn_interp_model == "nearest" else False),
|
|
antialias=False,
|
|
)
|
|
prev_features = lateral_features + top_down_features
|
|
if self.fuse_type == "avg":
|
|
prev_features /= 2
|
|
else:
|
|
prev_features = lateral_features
|
|
x_out = prev_features
|
|
out[i] = x_out
|
|
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
|
|
|
return out, pos
|
|
|
|
|
|
class Hiera(nn.Module):
|
|
"""Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks."""
|
|
|
|
def __init__(
|
|
self,
|
|
embed_dim: int = 96, # initial embed dim
|
|
num_heads: int = 1, # initial number of heads
|
|
drop_path_rate: float = 0.0, # stochastic depth
|
|
q_pool: int = 3, # number of q_pool stages
|
|
q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
|
|
stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
|
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
|
head_mul: float = 2.0, # head_mul factor at stage shift
|
|
window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
|
|
# window size per stage, when not using global att.
|
|
window_spec: Tuple[int, ...] = (
|
|
8,
|
|
4,
|
|
14,
|
|
7,
|
|
),
|
|
# global attn in these blocks
|
|
global_att_blocks: Tuple[int, ...] = (
|
|
12,
|
|
16,
|
|
20,
|
|
),
|
|
return_interm_layers=True, # return feats from every stage
|
|
):
|
|
"""Initializes a Hiera model with configurable architecture for hierarchical vision transformers."""
|
|
super().__init__()
|
|
|
|
assert len(stages) == len(window_spec)
|
|
self.window_spec = window_spec
|
|
|
|
depth = sum(stages)
|
|
self.q_stride = q_stride
|
|
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
|
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
|
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
|
|
self.return_interm_layers = return_interm_layers
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
embed_dim=embed_dim,
|
|
kernel_size=(7, 7),
|
|
stride=(4, 4),
|
|
padding=(3, 3),
|
|
)
|
|
# Which blocks have global att?
|
|
self.global_att_blocks = global_att_blocks
|
|
|
|
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
|
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
|
|
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
cur_stage = 1
|
|
self.blocks = nn.ModuleList()
|
|
|
|
for i in range(depth):
|
|
dim_out = embed_dim
|
|
# lags by a block, so first block of
|
|
# next stage uses an initial window size
|
|
# of previous stage and final window size of current stage
|
|
window_size = self.window_spec[cur_stage - 1]
|
|
|
|
if self.global_att_blocks is not None:
|
|
window_size = 0 if i in self.global_att_blocks else window_size
|
|
|
|
if i - 1 in self.stage_ends:
|
|
dim_out = int(embed_dim * dim_mul)
|
|
num_heads = int(num_heads * head_mul)
|
|
cur_stage += 1
|
|
|
|
block = MultiScaleBlock(
|
|
dim=embed_dim,
|
|
dim_out=dim_out,
|
|
num_heads=num_heads,
|
|
drop_path=dpr[i],
|
|
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
|
window_size=window_size,
|
|
)
|
|
|
|
embed_dim = dim_out
|
|
self.blocks.append(block)
|
|
|
|
self.channel_list = (
|
|
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
|
if return_interm_layers
|
|
else [self.blocks[-1].dim_out]
|
|
)
|
|
|
|
def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
|
|
"""Generate positional embeddings by interpolating and combining window and background embeddings."""
|
|
h, w = hw
|
|
window_embed = self.pos_embed_window
|
|
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
|
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
|
|
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
|
return pos_embed
|
|
|
|
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
|
"""Performs hierarchical vision transformer forward pass, returning multiscale feature maps."""
|
|
x = self.patch_embed(x)
|
|
# x: (B, H, W, C)
|
|
|
|
# Add pos embed
|
|
x = x + self._get_pos_embed(x.shape[1:3])
|
|
|
|
outputs = []
|
|
for i, blk in enumerate(self.blocks):
|
|
x = blk(x)
|
|
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
|
|
feats = x.permute(0, 3, 1, 2)
|
|
outputs.append(feats)
|
|
|
|
return outputs
|