From 8648572809fa2e58967862f9a4748abddd0f60a7 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 30 Jul 2024 22:06:49 +0800 Subject: [PATCH] `ultralytics 8.2.70` Segment Anything Model 2 (SAM 2) (#14813) Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- docs/en/models/sam-2.md | 2 + .../reference/models/sam/modules/decoders.md | 4 - docs/en/reference/models/sam/modules/sam.md | 6 +- docs/en/reference/models/sam2/build.md | 36 + docs/en/reference/models/sam2/model.md | 16 + .../reference/models/sam2/modules/decoders.md | 16 + .../reference/models/sam2/modules/encoders.md | 28 + .../models/sam2/modules/memory_attention.md | 20 + docs/en/reference/models/sam2/modules/sam2.md | 16 + .../models/sam2/modules/sam2_blocks.md | 56 ++ .../en/reference/models/sam2/modules/utils.md | 44 + docs/en/reference/models/sam2/predict.md | 16 + mkdocs.yml | 13 +- ultralytics/__init__.py | 5 +- ultralytics/cfg/__init__.py | 4 + ultralytics/models/__init__.py | 3 +- ultralytics/models/fastsam/predict.py | 1 + ultralytics/models/sam/build.py | 4 +- ultralytics/models/sam/model.py | 12 +- ultralytics/models/sam/modules/decoders.py | 43 +- ultralytics/models/sam/modules/encoders.py | 4 +- ultralytics/models/sam/modules/sam.py | 12 +- ultralytics/models/sam/modules/transformer.py | 7 +- ultralytics/models/sam/predict.py | 18 +- ultralytics/models/sam2/__init__.py | 6 + ultralytics/models/sam2/build.py | 156 ++++ ultralytics/models/sam2/model.py | 97 +++ ultralytics/models/sam2/modules/__init__.py | 1 + ultralytics/models/sam2/modules/decoders.py | 305 +++++++ ultralytics/models/sam2/modules/encoders.py | 332 ++++++++ .../models/sam2/modules/memory_attention.py | 170 ++++ ultralytics/models/sam2/modules/sam2.py | 804 ++++++++++++++++++ .../models/sam2/modules/sam2_blocks.py | 715 ++++++++++++++++ ultralytics/models/sam2/modules/utils.py | 191 +++++ ultralytics/models/sam2/predict.py | 182 ++++ ultralytics/nn/modules/transformer.py | 8 +- 36 files changed, 3276 insertions(+), 77 deletions(-) create mode 100644 docs/en/reference/models/sam2/build.md create mode 100644 docs/en/reference/models/sam2/model.md create mode 100644 docs/en/reference/models/sam2/modules/decoders.md create mode 100644 docs/en/reference/models/sam2/modules/encoders.md create mode 100644 docs/en/reference/models/sam2/modules/memory_attention.md create mode 100644 docs/en/reference/models/sam2/modules/sam2.md create mode 100644 docs/en/reference/models/sam2/modules/sam2_blocks.md create mode 100644 docs/en/reference/models/sam2/modules/utils.md create mode 100644 docs/en/reference/models/sam2/predict.md create mode 100644 ultralytics/models/sam2/__init__.py create mode 100644 ultralytics/models/sam2/build.py create mode 100644 ultralytics/models/sam2/model.py create mode 100644 ultralytics/models/sam2/modules/__init__.py create mode 100644 ultralytics/models/sam2/modules/decoders.py create mode 100644 ultralytics/models/sam2/modules/encoders.py create mode 100644 ultralytics/models/sam2/modules/memory_attention.py create mode 100644 ultralytics/models/sam2/modules/sam2.py create mode 100644 ultralytics/models/sam2/modules/sam2_blocks.py create mode 100644 ultralytics/models/sam2/modules/utils.py create mode 100644 ultralytics/models/sam2/predict.py diff --git a/docs/en/models/sam-2.md b/docs/en/models/sam-2.md index 85d7f24c..001bd89b 100644 --- a/docs/en/models/sam-2.md +++ b/docs/en/models/sam-2.md @@ -113,6 +113,8 @@ The following table details the available SAM 2 models, their pre-trained weight | Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export | | ----------- | ------------------------------------------------------------------------------------- | -------------------------------------------- | --------- | ---------- | -------- | ------ | +| SAM 2 tiny | [sam2_t.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_t.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | +| SAM 2 small | [sam2_s.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_s.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | | SAM 2 base | [sam2_b.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_b.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | | SAM 2 large | [sam2_l.pt](https://github.com/ultralytics/assets/releases/download/v8.2.0/sam2_l.pt) | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/en/reference/models/sam/modules/decoders.md b/docs/en/reference/models/sam/modules/decoders.md index ff3aaa9e..a224738b 100644 --- a/docs/en/reference/models/sam/modules/decoders.md +++ b/docs/en/reference/models/sam/modules/decoders.md @@ -13,8 +13,4 @@ keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architect ## ::: ultralytics.models.sam.modules.decoders.MaskDecoder -



- -## ::: ultralytics.models.sam.modules.decoders.MLP -

diff --git a/docs/en/reference/models/sam/modules/sam.md b/docs/en/reference/models/sam/modules/sam.md index 5a7c30e4..a31cece1 100644 --- a/docs/en/reference/models/sam/modules/sam.md +++ b/docs/en/reference/models/sam/modules/sam.md @@ -1,6 +1,6 @@ --- -description: Discover the Ultralytics Sam module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. -keywords: Ultralytics, Sam Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning +description: Discover the Ultralytics SAM module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. +keywords: Ultralytics, SAM Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning --- # Reference for `ultralytics/models/sam/modules/sam.py` @@ -11,6 +11,6 @@ keywords: Ultralytics, Sam Module, object segmentation, image encoder, mask deco
-## ::: ultralytics.models.sam.modules.sam.Sam +## ::: ultralytics.models.sam.modules.sam.SAMModel

diff --git a/docs/en/reference/models/sam2/build.md b/docs/en/reference/models/sam2/build.md new file mode 100644 index 00000000..94e9f5be --- /dev/null +++ b/docs/en/reference/models/sam2/build.md @@ -0,0 +1,36 @@ +--- +description: Discover detailed instructions for building various Segment Anything Model 2 (SAM 2) architectures with Ultralytics. +keywords: Ultralytics, SAM 2 model, Segment Anything Model 2, SAM, model building, deep learning, AI +--- + +# Reference for `ultralytics/models/sam2/build.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/build.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/build.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.build.build_sam2_t + +



+ +## ::: ultralytics.models.sam2.build.build_sam2_s + +



+ +## ::: ultralytics.models.sam2.build.build_sam2_b + +



+ +## ::: ultralytics.models.sam2.build.build_sam2_l + +



+ +## ::: ultralytics.models.sam2.build._build_sam2 + +



+ +## ::: ultralytics.models.sam2.build.build_sam2 + +

diff --git a/docs/en/reference/models/sam2/model.md b/docs/en/reference/models/sam2/model.md new file mode 100644 index 00000000..fc0d2e25 --- /dev/null +++ b/docs/en/reference/models/sam2/model.md @@ -0,0 +1,16 @@ +--- +description: Explore the SAM 2 (Segment Anything Model 2) interface for real-time image segmentation. Learn about promptable segmentation and zero-shot capabilities. +keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time segmentation, zero-shot performance, promptable segmentation, SA-1B dataset +--- + +# Reference for `ultralytics/models/sam2/model.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/model.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/model.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.model.SAM2 + +

diff --git a/docs/en/reference/models/sam2/modules/decoders.md b/docs/en/reference/models/sam2/modules/decoders.md new file mode 100644 index 00000000..989169a7 --- /dev/null +++ b/docs/en/reference/models/sam2/modules/decoders.md @@ -0,0 +1,16 @@ +--- +description: Explore the MaskDecoder and MLP modules in Ultralytics for efficient mask prediction using transformer architecture. Detailed attributes, functionalities, and implementation. +keywords: Ultralytics, MaskDecoder, MLP, machine learning, transformer architecture, mask prediction, neural networks, PyTorch modules +--- + +# Reference for `ultralytics/models/sam2/modules/decoders.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/decoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/decoders.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.decoders.MaskDecoder + +

diff --git a/docs/en/reference/models/sam2/modules/encoders.md b/docs/en/reference/models/sam2/modules/encoders.md new file mode 100644 index 00000000..a3da286e --- /dev/null +++ b/docs/en/reference/models/sam2/modules/encoders.md @@ -0,0 +1,28 @@ +--- +description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. +keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning +--- + +# Reference for `ultralytics/models/sam2/modules/encoders.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/encoders.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/encoders.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.encoders.MemoryEncoder + +



+ +## ::: ultralytics.models.sam2.modules.encoders.ImageEncoder + +



+ +## ::: ultralytics.models.sam2.modules.encoders.FpnNeck + +



+ +## ::: ultralytics.models.sam2.modules.encoders.Hiera + +

diff --git a/docs/en/reference/models/sam2/modules/memory_attention.md b/docs/en/reference/models/sam2/modules/memory_attention.md new file mode 100644 index 00000000..fef22174 --- /dev/null +++ b/docs/en/reference/models/sam2/modules/memory_attention.md @@ -0,0 +1,20 @@ +--- +description: Explore detailed documentation of various SAM 2 encoder modules such as MemoryAttentionLayer, MemoryAttention, available in Ultralytics' repository. +keywords: Ultralytics, SAM 2 encoder, MemoryAttentionLayer, MemoryAttention +--- + +# Reference for `ultralytics/models/sam2/modules/memory_attention.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/memory_attention.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/memory_attention.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttentionLayer + +



+ +## ::: ultralytics.models.sam2.modules.memory_attention.MemoryAttention + +

diff --git a/docs/en/reference/models/sam2/modules/sam2.md b/docs/en/reference/models/sam2/modules/sam2.md new file mode 100644 index 00000000..fcd47bbd --- /dev/null +++ b/docs/en/reference/models/sam2/modules/sam2.md @@ -0,0 +1,16 @@ +--- +description: Discover the Ultralytics SAM 2 module for object segmentation. Learn about its components, such as image encoders and mask decoders, in this comprehensive guide. +keywords: Ultralytics, SAM 2 Module, object segmentation, image encoder, mask decoder, prompt encoder, AI, machine learning +--- + +# Reference for `ultralytics/models/sam2/modules/sam2.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.sam2.SAM2Model + +

diff --git a/docs/en/reference/models/sam2/modules/sam2_blocks.md b/docs/en/reference/models/sam2/modules/sam2_blocks.md new file mode 100644 index 00000000..796669a0 --- /dev/null +++ b/docs/en/reference/models/sam2/modules/sam2_blocks.md @@ -0,0 +1,56 @@ +--- +description: Explore detailed documentation of various SAM 2 modules such as MaskDownSampler, CXBlock, and more, available in Ultralytics' repository. +keywords: Ultralytics, SAM 2 encoder, DropPath, MaskDownSampler, CXBlock, Fuser, TwoWayTransformer, TwoWayAttentionBlock, RoPEAttention, MultiScaleAttention, MultiScaleBlock. PositionEmbeddingSine, do_pool +--- + +# Reference for `ultralytics/models/sam2/modules/sam2_blocks.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/sam2_blocks.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/sam2_blocks.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.DropPath + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.MaskDownSampler + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.CXBlock + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.Fuser + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayAttentionBlock + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.TwoWayTransformer + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.RoPEAttention + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleAttention + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.MultiScaleBlock + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.PositionEmbeddingSine + +



+ +## ::: ultralytics.models.sam2.modules.sam2_blocks.do_pool + +

diff --git a/docs/en/reference/models/sam2/modules/utils.md b/docs/en/reference/models/sam2/modules/utils.md new file mode 100644 index 00000000..357cea62 --- /dev/null +++ b/docs/en/reference/models/sam2/modules/utils.md @@ -0,0 +1,44 @@ +--- +description: Explore the detailed API reference for Ultralytics SAM 2 models. +keywords: Ultralytics, SAM 2, API Reference, models, window partition, data processing, YOLO +--- + +# Reference for `ultralytics/models/sam2/modules/utils.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/modules/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/modules/utils.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.modules.utils.select_closest_cond_frames + +



+ +## ::: ultralytics.models.sam2.modules.utils.get_1d_sine_pe + +



+ +## ::: ultralytics.models.sam2.modules.utils.init_t_xy + +



+ +## ::: ultralytics.models.sam2.modules.utils.compute_axial_cis + +



+ +## ::: ultralytics.models.sam2.modules.utils.reshape_for_broadcast + +



+ +## ::: ultralytics.models.sam2.modules.utils.apply_rotary_enc + +



+ +## ::: ultralytics.models.sam2.modules.utils.window_partition + +



+ +## ::: ultralytics.models.sam2.modules.utils.window_unpartition + +

diff --git a/docs/en/reference/models/sam2/predict.md b/docs/en/reference/models/sam2/predict.md new file mode 100644 index 00000000..23b30140 --- /dev/null +++ b/docs/en/reference/models/sam2/predict.md @@ -0,0 +1,16 @@ +--- +description: Explore Ultralytics SAM 2 Predictor for advanced, real-time image segmentation using the Segment Anything Model 2 (SAM 2). Complete implementation details and auxiliary utilities. +keywords: Ultralytics, SAM 2, Segment Anything Model 2, image segmentation, real-time, prediction, AI, machine learning, Python, torch, inference +--- + +# Reference for `ultralytics/models/sam2/predict.py` + +!!! Note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/sam2/predict.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/sam2/predict.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.models.sam2.predict.SAM2Predictor + +

diff --git a/mkdocs.yml b/mkdocs.yml index 768429d4..fcf5516b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -239,7 +239,7 @@ nav: - YOLOv9: models/yolov9.md - YOLOv10: models/yolov10.md - SAM (Segment Anything Model): models/sam.md - - SAM2 (Segment Anything Model 2): models/sam-2.md + - SAM 2 (Segment Anything Model 2): models/sam-2.md - MobileSAM (Mobile Segment Anything Model): models/mobile-sam.md - FastSAM (Fast Segment Anything Model): models/fast-sam.md - YOLO-NAS (Neural Architecture Search): models/yolo-nas.md @@ -509,6 +509,17 @@ nav: - tiny_encoder: reference/models/sam/modules/tiny_encoder.md - transformer: reference/models/sam/modules/transformer.md - predict: reference/models/sam/predict.md + - sam2: + - build: reference/models/sam2/build.md + - model: reference/models/sam2/model.md + - modules: + - decoders: reference/models/sam2/modules/decoders.md + - encoders: reference/models/sam2/modules/encoders.md + - memory_attention: reference/models/sam2/modules/memory_attention.md + - sam2: reference/models/sam2/modules/sam2.md + - sam2_blocks: reference/models/sam2/modules/sam2_blocks.md + - utils: reference/models/sam2/modules/utils.md + - predict: reference/models/sam2/predict.md - utils: - loss: reference/models/utils/loss.md - ops: reference/models/utils/ops.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 8c313528..affb8e35 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.69" +__version__ = "8.2.70" import os @@ -8,7 +8,7 @@ import os os.environ["OMP_NUM_THREADS"] = "1" # reduce CPU utilization during training from ultralytics.data.explorer.explorer import Explorer -from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld +from ultralytics.models import NAS, RTDETR, SAM, SAM2, YOLO, FastSAM, YOLOWorld from ultralytics.utils import ASSETS, SETTINGS from ultralytics.utils.checks import check_yolo as checks from ultralytics.utils.downloads import download @@ -21,6 +21,7 @@ __all__ = ( "YOLOWorld", "NAS", "SAM", + "SAM2", "FastSAM", "RTDETR", "checks", diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 4f7ad00e..4a23ec42 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -793,6 +793,10 @@ def entrypoint(debug=""): from ultralytics import FastSAM model = FastSAM(model) + elif "sam2" in stem: + from ultralytics import SAM2 + + model = SAM2(model) elif "sam" in stem: from ultralytics import SAM diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py index aff620a9..bbade047 100644 --- a/ultralytics/models/__init__.py +++ b/ultralytics/models/__init__.py @@ -4,6 +4,7 @@ from .fastsam import FastSAM from .nas import NAS from .rtdetr import RTDETR from .sam import SAM +from .sam2 import SAM2 from .yolo import YOLO, YOLOWorld -__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import +__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "SAM2" # allow simpler import diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py index cd9b3023..2a15c3f3 100644 --- a/ultralytics/models/fastsam/predict.py +++ b/ultralytics/models/fastsam/predict.py @@ -21,6 +21,7 @@ class FastSAMPredictor(SegmentationPredictor): """ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework.""" super().__init__(cfg, overrides, _callbacks) self.prompts = {} diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py index cb3a7c68..253c0159 100644 --- a/ultralytics/models/sam/build.py +++ b/ultralytics/models/sam/build.py @@ -14,7 +14,7 @@ from ultralytics.utils.downloads import attempt_download_asset from .modules.decoders import MaskDecoder from .modules.encoders import ImageEncoderViT, PromptEncoder -from .modules.sam import Sam +from .modules.sam import SAMModel from .modules.tiny_encoder import TinyViT from .modules.transformer import TwoWayTransformer @@ -105,7 +105,7 @@ def _build_sam( out_chans=prompt_embed_dim, ) ) - sam = Sam( + sam = SAMModel( image_encoder=image_encoder, prompt_encoder=PromptEncoder( embed_dim=prompt_embed_dim, diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index feaeb551..a819b6ea 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -44,6 +44,7 @@ class SAM(Model): """ if model and Path(model).suffix not in {".pt", ".pth"}: raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") + self.is_sam2 = "sam2" in Path(model).stem super().__init__(model=model, task="segment") def _load(self, weights: str, task=None): @@ -54,7 +55,12 @@ class SAM(Model): weights (str): Path to the weights file. task (str, optional): Task name. Defaults to None. """ - self.model = build_sam(weights) + if self.is_sam2: + from ..sam2.build import build_sam2 + + self.model = build_sam2(weights) + else: + self.model = build_sam(weights) def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): """ @@ -112,4 +118,6 @@ class SAM(Model): Returns: (dict): A dictionary mapping the 'segment' task to its corresponding 'Predictor'. """ - return {"segment": {"predictor": Predictor}} + from ..sam2.predict import SAM2Predictor + + return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}} diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py index 073b1ad4..eeaab6b4 100644 --- a/ultralytics/models/sam/modules/decoders.py +++ b/ultralytics/models/sam/modules/decoders.py @@ -4,9 +4,8 @@ from typing import List, Tuple, Type import torch from torch import nn -from torch.nn import functional as F -from ultralytics.nn.modules import LayerNorm2d +from ultralytics.nn.modules import MLP, LayerNorm2d class MaskDecoder(nn.Module): @@ -28,7 +27,6 @@ class MaskDecoder(nn.Module): def __init__( self, - *, transformer_dim: int, transformer: nn.Module, num_multimask_outputs: int = 3, @@ -149,42 +147,3 @@ class MaskDecoder(nn.Module): iou_pred = self.iou_prediction_head(iou_token_out) return masks, iou_pred - - -class MLP(nn.Module): - """ - MLP (Multi-Layer Perceptron) model lightly adapted from - https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py - """ - - def __init__( - self, - input_dim: int, - hidden_dim: int, - output_dim: int, - num_layers: int, - sigmoid_output: bool = False, - ) -> None: - """ - Initializes the MLP (Multi-Layer Perceptron) model. - - Args: - input_dim (int): The dimensionality of the input features. - hidden_dim (int): The dimensionality of the hidden layers. - output_dim (int): The dimensionality of the output layer. - num_layers (int): The number of hidden layers. - sigmoid_output (bool, optional): Apply a sigmoid activation to the output layer. Defaults to False. - """ - super().__init__() - self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - self.sigmoid_output = sigmoid_output - - def forward(self, x): - """Executes feedforward within the neural network module and applies activation.""" - for i, layer in enumerate(self.layers): - x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) - if self.sigmoid_output: - x = torch.sigmoid(x) - return x diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py index a51c3472..2bf3945d 100644 --- a/ultralytics/models/sam/modules/encoders.py +++ b/ultralytics/models/sam/modules/encoders.py @@ -211,6 +211,8 @@ class PromptEncoder(nn.Module): point_embedding[labels == -1] += self.not_a_point_embed.weight point_embedding[labels == 0] += self.point_embeddings[0].weight point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight return point_embedding def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: @@ -226,8 +228,8 @@ class PromptEncoder(nn.Module): """Embeds mask inputs.""" return self.mask_downscaling(masks) + @staticmethod def _get_batch_size( - self, points: Optional[Tuple[torch.Tensor, torch.Tensor]], boxes: Optional[torch.Tensor], masks: Optional[torch.Tensor], diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 95d9bbe6..1617527f 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -15,15 +15,14 @@ from .decoders import MaskDecoder from .encoders import ImageEncoderViT, PromptEncoder -class Sam(nn.Module): +class SAMModel(nn.Module): """ - Sam (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate image - embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by the mask - decoder to predict object masks. + SAMModel (Segment Anything Model) is designed for object segmentation tasks. It uses image encoders to generate + image embeddings, and prompt encoders to encode various types of input prompts. These embeddings are then used by + the mask decoder to predict object masks. Attributes: mask_threshold (float): Threshold value for mask prediction. - image_format (str): Format of the input image, default is 'RGB'. image_encoder (ImageEncoderViT): The backbone used to encode the image into embeddings. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts object masks from the image and prompt embeddings. @@ -32,7 +31,6 @@ class Sam(nn.Module): """ mask_threshold: float = 0.0 - image_format: str = "RGB" def __init__( self, @@ -43,7 +41,7 @@ class Sam(nn.Module): pixel_std: List[float] = (58.395, 57.12, 57.375), ) -> None: """ - Initialize the Sam class to predict object masks from an image and input prompts. + Initialize the SAMModel class to predict object masks from an image and input prompts. Note: All forward() operations moved to SAMPredictor. diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py index db684f8f..6375c2ad 100644 --- a/ultralytics/models/sam/modules/transformer.py +++ b/ultralytics/models/sam/modules/transformer.py @@ -86,7 +86,6 @@ class TwoWayTransformer(nn.Module): (torch.Tensor): the processed image_embedding """ # BxCxHxW -> BxHWxC == B x N_image_tokens x C - bs, c, h, w = image_embedding.shape image_embedding = image_embedding.flatten(2).permute(0, 2, 1) image_pe = image_pe.flatten(2).permute(0, 2, 1) @@ -212,6 +211,7 @@ class Attention(nn.Module): embedding_dim: int, num_heads: int, downsample_rate: int = 1, + kv_in_dim: int = None, ) -> None: """ Initializes the Attention model with the given dimensions and settings. @@ -226,13 +226,14 @@ class Attention(nn.Module): """ super().__init__() self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim self.internal_dim = embedding_dim // downsample_rate self.num_heads = num_heads assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." self.q_proj = nn.Linear(embedding_dim, self.internal_dim) - self.k_proj = nn.Linear(embedding_dim, self.internal_dim) - self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) self.out_proj = nn.Linear(self.internal_dim, embedding_dim) @staticmethod diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 41076602..e03ba625 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -168,7 +168,7 @@ class Predictor(BasePredictor): - np.ndarray: An array of length C containing quality scores predicted by the model for each mask. - np.ndarray: Low-resolution logits of shape CxHxW for subsequent inference, where H=W=256. """ - features = self.model.image_encoder(im) if self.features is None else self.features + features = self.get_im_features(im) if self.features is None else self.features src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) @@ -334,7 +334,7 @@ class Predictor(BasePredictor): """ device = select_device(self.args.device, verbose=verbose) if model is None: - model = build_sam(self.args.model) + model = self.get_model() model.eval() self.model = model.to(device) self.device = device @@ -348,6 +348,10 @@ class Predictor(BasePredictor): self.model.fp16 = False self.done_warmup = True + def get_model(self): + """Built Segment Anything Model (SAM) model.""" + return build_sam(self.args.model) + def postprocess(self, preds, img, orig_imgs): """ Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. @@ -412,16 +416,18 @@ class Predictor(BasePredictor): AssertionError: If more than one image is set. """ if self.model is None: - model = build_sam(self.args.model) - self.setup_model(model) + self.setup_model(model=None) self.setup_source(image) assert len(self.dataset) == 1, "`set_image` only supports setting one image!" for batch in self.dataset: im = self.preprocess(batch[1]) - self.features = self.model.image_encoder(im) - self.im = im + self.features = self.get_im_features(im) break + def get_im_features(self, im): + """Get image features from the SAM image encoder.""" + return self.model.image_encoder(im) + def set_prompts(self, prompts): """Set prompts in advance.""" self.prompts = prompts diff --git a/ultralytics/models/sam2/__init__.py b/ultralytics/models/sam2/__init__.py new file mode 100644 index 00000000..755a160a --- /dev/null +++ b/ultralytics/models/sam2/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from .model import SAM2 +from .predict import SAM2Predictor + +__all__ = "SAM2", "SAM2Predictor" # tuple or list diff --git a/ultralytics/models/sam2/build.py b/ultralytics/models/sam2/build.py new file mode 100644 index 00000000..2791029a --- /dev/null +++ b/ultralytics/models/sam2/build.py @@ -0,0 +1,156 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import torch + +from ultralytics.utils.downloads import attempt_download_asset + +from .modules.encoders import FpnNeck, Hiera, ImageEncoder, MemoryEncoder +from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer +from .modules.sam2 import SAM2Model + + +def build_sam2_t(checkpoint=None): + """Build and return a Segment Anything Model (SAM2) tiny-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 7, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[5, 7, 9], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_s(checkpoint=None): + """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 11, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[7, 10, 13], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_b(checkpoint=None): + """Builds and returns a Segment Anything Model (SAM2) base-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=112, + encoder_stages=[2, 3, 16, 3], + encoder_num_heads=2, + encoder_global_att_blocks=[12, 16, 20], + encoder_window_spec=[8, 4, 14, 7], + encoder_window_spatial_size=[14, 14], + encoder_backbone_channel_list=[896, 448, 224, 112], + checkpoint=checkpoint, + ) + + +def build_sam2_l(checkpoint=None): + """Build and return a Segment Anything Model (SAM2) large-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=144, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[23, 33, 43], + encoder_window_spec=[8, 4, 16, 8], + encoder_backbone_channel_list=[1152, 576, 288, 144], + checkpoint=checkpoint, + ) + + +def _build_sam2( + encoder_embed_dim=1280, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[7, 15, 23, 31], + encoder_backbone_channel_list=[1152, 576, 288, 144], + encoder_window_spatial_size=[7, 7], + encoder_window_spec=[8, 4, 16, 8], + checkpoint=None, +): + """Builds a SAM2 model with specified architecture parameters and optional checkpoint loading.""" + image_encoder = ImageEncoder( + trunk=Hiera( + embed_dim=encoder_embed_dim, + num_heads=encoder_num_heads, + stages=encoder_stages, + global_att_blocks=encoder_global_att_blocks, + window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, + window_spec=encoder_window_spec, + ), + neck=FpnNeck( + d_model=256, + backbone_channel_list=encoder_backbone_channel_list, + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + scalp=1, + ) + memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) + memory_encoder = MemoryEncoder(out_dim=64) + + sam2 = SAM2Model( + image_encoder=image_encoder, + memory_attention=memory_attention, + memory_encoder=memory_encoder, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, + sam_mask_decoder_extra_args=dict( + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + ), + ) + + if checkpoint is not None: + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch.load(f)["model"] + sam2.load_state_dict(state_dict) + sam2.eval() + return sam2 + + +sam_model_map = { + "sam2_t.pt": build_sam2_t, + "sam2_s.pt": build_sam2_s, + "sam2_b.pt": build_sam2_b, + "sam2_l.pt": build_sam2_l, +} + + +def build_sam2(ckpt="sam_b.pt"): + """Constructs a Segment Anything Model (SAM2) based on the specified checkpoint, with various size options.""" + model_builder = None + ckpt = str(ckpt) # to allow Path ckpt types + for k in sam_model_map.keys(): + if ckpt.endswith(k): + model_builder = sam_model_map.get(k) + + if not model_builder: + raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") + + return model_builder(ckpt) diff --git a/ultralytics/models/sam2/model.py b/ultralytics/models/sam2/model.py new file mode 100644 index 00000000..4b3265e1 --- /dev/null +++ b/ultralytics/models/sam2/model.py @@ -0,0 +1,97 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +""" +SAM2 model interface. + +This module provides an interface to the Segment Anything Model (SAM2) from Ultralytics, designed for real-time image +segmentation tasks. The SAM2 model allows for promptable segmentation with unparalleled versatility in image analysis, +and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new +image distributions and tasks without prior knowledge. + +Key Features: + - Promptable segmentation + - Real-time performance + - Zero-shot transfer capabilities + - Trained on SA-1B dataset +""" + +from ultralytics.models.sam import SAM + +from .build import build_sam2 +from .predict import SAM2Predictor + + +class SAM2(SAM): + """ + SAM2 class for real-time image segmentation using the Segment Anything Model (SAM2). + + This class extends the SAM base class, providing an interface to the SAM2 model for promptable segmentation + tasks. It supports loading pre-trained weights and offers zero-shot performance capabilities. + + Attributes: + model (torch.nn.Module): The loaded SAM2 model. + task_map (Dict[str, Type[SAM2Predictor]]): Mapping of 'segment' task to SAM2Predictor. + + Methods: + __init__: Initializes the SAM2 model with pre-trained weights. + _load: Loads specified weights into the SAM2 model. + + Examples: + >>> sam2 = SAM2("sam2_b.pt") + >>> sam2._load('path/to/sam2_weights.pt') + >>> task_map = sam2.task_map + >>> print(task_map) + {'segment': SAM2Predictor} + + Notes: + - Supports .pt and .pth file extensions for model weights. + - Offers zero-shot transfer capabilities for new image distributions and tasks. + """ + + def __init__(self, model="sam2_b.pt") -> None: + """ + Initializes the SAM2 model with a pre-trained model file. + + Args: + model (str): Path to the pre-trained SAM2 model file. File should have a .pt or .pth extension. + + Raises: + NotImplementedError: If the model file extension is not .pt or .pth. + + Examples: + >>> sam2 = SAM2("sam2_b.pt") + """ + super().__init__(model=model) + + def _load(self, weights: str, task=None): + """ + Loads the specified weights into the SAM2 model. + + This method is responsible for loading pre-trained weights into the SAM2 model. It supports loading + weights from files with .pt or .pth extensions. + + Args: + weights (str): Path to the weights file. Should be a file with .pt or .pth extension. + task (str | None): Task name. If provided, it may be used to configure model-specific settings. + + Examples: + >>> sam2_model = SAM2() + >>> sam2_model._load('path/to/sam2_weights.pt') + """ + self.model = build_sam2(weights) + + @property + def task_map(self): + """ + Provides a mapping from the 'segment' task to its corresponding 'Predictor'. + + Returns: + (Dict[str, Type[SAM2Predictor]]): A dictionary mapping the 'segment' task to its corresponding + SAM2Predictor class. + + Examples: + >>> sam2 = SAM2() + >>> task_map = sam2.task_map + >>> print(task_map) + {'segment': SAM2Predictor} + """ + return {"segment": {"predictor": SAM2Predictor}} diff --git a/ultralytics/models/sam2/modules/__init__.py b/ultralytics/models/sam2/modules/__init__.py new file mode 100644 index 00000000..9e68dc12 --- /dev/null +++ b/ultralytics/models/sam2/modules/__init__.py @@ -0,0 +1 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license diff --git a/ultralytics/models/sam2/modules/decoders.py b/ultralytics/models/sam2/modules/decoders.py new file mode 100644 index 00000000..ac6c60a6 --- /dev/null +++ b/ultralytics/models/sam2/modules/decoders.py @@ -0,0 +1,305 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from ultralytics.nn.modules import MLP, LayerNorm2d + + +class MaskDecoder(nn.Module): + """Transformer-based decoder predicting instance segmentation masks from image and prompt embeddings.""" + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Initializes the MaskDecoder module for predicting instance segmentation masks. + + Args: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + use_high_res_features (bool): Whether to use high-resolution features. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + pred_obj_scores (bool): Whether to predict object scores. + pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. + + Attributes: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + iou_token (nn.Embedding): Embedding for IOU token. + num_mask_tokens (int): Total number of mask tokens. + mask_tokens (nn.Embedding): Embedding for mask tokens. + pred_obj_scores (bool): Whether to predict object scores. + obj_score_token (nn.Embedding): Embedding for object score token. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. + output_upscaling (nn.Sequential): Upscaling layers for output. + use_high_res_features (bool): Whether to use high-resolution features. + conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). + conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). + output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. + iou_prediction_head (MLP): MLP for IOU prediction. + pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) + + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder. + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. + multimask_output (bool): Whether to return multiple masks or a single mask. + repeat_image (bool): Flag to repeat the image embeddings. + high_res_features (List[torch.Tensor] | None): Optional high-resolution features. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: + - masks (torch.Tensor): Batched predicted masks. + - iou_pred (torch.Tensor): Batched predictions of mask quality. + - sam_tokens_out (torch.Tensor): Batched SAM token for mask output. + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out = decoder.forward(image_embeddings, image_pe, + ... sparse_prompt_embeddings, dense_prompt_embeddings, True, False) + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts instance segmentation masks from image and prompt embeddings using a transformer architecture.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """Computes mask stability scores based on IoU between upper and lower thresholds.""" + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + Dynamically selects the most stable mask output based on stability scores and IoU predictions. + + When outputting a single mask, if the stability score from the current single-mask output (based on output token + 0) falls below a threshold, we instead select from multi-mask outputs (based on output token 1~3) the mask with + the highest predicted IoU score. + + This is intended to ensure a valid mask for both clicking and tracking. + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/ultralytics/models/sam2/modules/encoders.py b/ultralytics/models/sam2/modules/encoders.py new file mode 100644 index 00000000..b4cd0f8f --- /dev/null +++ b/ultralytics/models/sam2/modules/encoders.py @@ -0,0 +1,332 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.models.sam.modules.encoders import PatchEmbed + +from .sam2_blocks import CXBlock, Fuser, MaskDownSampler, MultiScaleBlock, PositionEmbeddingSine + + +class MemoryEncoder(nn.Module): + """Encodes pixel features and masks into a memory representation for efficient image segmentation.""" + + def __init__( + self, + out_dim, + in_dim=256, # in_dim of pix_feats + ): + """Initializes the MemoryEncoder module for encoding pixel features and masks in SAM-like models.""" + super().__init__() + + self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = Fuser(CXBlock(dim=256), num_layers=2) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=64) + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Processes pixel features and masks, fusing them to generate encoded memory representations.""" + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} + + +class ImageEncoder(nn.Module): + """Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings.""" + + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + """Initializes an image encoder with a trunk, neck, and optional scalp for feature extraction.""" + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match." + + def forward(self, sample: torch.Tensor): + """Processes image input through trunk and neck, returning features, positional encodings, and FPN outputs.""" + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.""" + + def __init__( + self, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """ + Initializes a modified Feature Pyramid Network (FPN) neck. + + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. + + Args: + d_model (int): Dimension of the model. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + kernel_size (int): Kernel size for the convolutional layers. + stride (int): Stride for the convolutional layers. + padding (int): Padding for the convolutional layers. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. + + Attributes: + position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding. + convs (nn.ModuleList): List of convolutional layers for each backbone level. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion. + fpn_top_down_levels (List[int]): Levels with top-down feature propagation. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> print(fpn_neck) + """ + super().__init__() + self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. + if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + """ + Performs forward pass through the Feature Pyramid Network (FPN) neck. + + Args: + xs (List[torch.Tensor]): List of input tensors from the backbone, with shape (B, C, H, W) for each tensor. + + Returns: + (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing two lists: + - out: List of output feature maps after FPN processing, with shape (B, d_model, H, W) for each tensor. + - pos: List of positional encodings corresponding to each output feature map. + + Examples: + >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) + >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] + >>> outputs, positions = fpn_neck(inputs) + """ + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=(None if self.fpn_interp_model == "nearest" else False), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos + + +class Hiera(nn.Module): + """Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.""" + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + """Initializes a Hiera model with configurable architecture for hierarchical vision transformers.""" + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + kernel_size=(7, 7), + stride=(4, 4), + padding=(3, 3), + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + """Generate positional embeddings by interpolating and combining window and background embeddings.""" + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Performs hierarchical vision transformer forward pass, returning multiscale feature maps.""" + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/ultralytics/models/sam2/modules/memory_attention.py b/ultralytics/models/sam2/modules/memory_attention.py new file mode 100644 index 00000000..8b5673c3 --- /dev/null +++ b/ultralytics/models/sam2/modules/memory_attention.py @@ -0,0 +1,170 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import copy +from typing import Optional + +import torch +from torch import Tensor, nn + +from .sam2_blocks import RoPEAttention + + +class MemoryAttentionLayer(nn.Module): + """Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.""" + + def __init__( + self, + d_model: int = 256, + dim_feedforward: int = 2048, + dropout: float = 0.1, + pos_enc_at_attn: bool = False, + pos_enc_at_cross_attn_keys: bool = True, + pos_enc_at_cross_attn_queries: bool = False, + ): + """Initializes a MemoryAttentionLayer with self-attention, cross-attention, and feedforward components.""" + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1) + self.cross_attn_image = RoPEAttention( + rope_k_repeat=True, + embedding_dim=256, + num_heads=1, + downsample_rate=1, + kv_in_dim=64, + ) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = nn.ReLU() + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism.""" + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + """Performs cross-attention between target and memory tensors using RoPEAttention mechanism.""" + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + """Performs self-attention, cross-attention, and MLP operations on input tensors for memory-based attention.""" + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + """Memory attention module for processing sequential data with self and cross-attention mechanisms.""" + + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + """Initializes MemoryAttention module with layers and normalization for attention processing.""" + super().__init__() + self.d_model = d_model + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + """Applies self-attention and cross-attention to input tensors, processing through multiple layers.""" + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/ultralytics/models/sam2/modules/sam2.py b/ultralytics/models/sam2/modules/sam2.py new file mode 100644 index 00000000..363241a1 --- /dev/null +++ b/ultralytics/models/sam2/modules/sam2.py @@ -0,0 +1,804 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import torch +import torch.distributed +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from ultralytics.models.sam.modules.encoders import PromptEncoder +from ultralytics.nn.modules import MLP + +from .decoders import MaskDecoder +from .sam2_blocks import TwoWayTransformer +from .utils import get_1d_sine_pe, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Model(torch.nn.Module): + """SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities.""" + + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. + max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + """Initializes SAM2Model model with image encoder, memory attention, and memory encoder components.""" + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print("Image encoder compilation is enabled. First forward pass will be slow.") + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + """Returns the device on which the model's parameters are stored.""" + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + """Processes input frames and prompts to generate object masks and scores in video sequences.""" + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Args: + backbone_features (torch.Tensor): Image features with shape (B, C, H, W). + point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. + 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute + pixel-unit coordinates in (x, y) format for P input points. + 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, + 0 means negative clicks, and -1 means padding. + mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the + same spatial size as the image. + high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes + (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps + for SAM decoder. + multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, + output only 1 mask and its IoU estimate. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): + low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. + high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. + ious: Tensor of shape (B, M) with estimated IoU for each output mask. + low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. + high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. + obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. + object_score_logits: Tensor of shape (B,) with object score logits. + + Where M is 3 if multimask_output=True, and 1 if multimask_output=False. + + Examples: + >>> backbone_features = torch.rand(1, 256, 32, 32) + >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} + >>> mask_inputs = torch.rand(1, 1, 512, 512) + >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) + >>> low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits = results + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). + sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """Processes mask inputs to generate output mask logits and object pointers without using SAM.""" + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). + out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Process image batch through encoder to extract multi-level features for SAM model.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features from the image backbone output.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). + feats = prev["maskmem_features"].cuda(non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). + if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encodes the current frame's features and predicted masks into a new memory representation.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). + pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, + mask_for_mem, + skip_mask_sigmoid=True, # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """Applies non-overlapping constraints to object masks, keeping highest scoring object at each location.""" + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/ultralytics/models/sam2/modules/sam2_blocks.py b/ultralytics/models/sam2/modules/sam2_blocks.py new file mode 100644 index 00000000..67dc587e --- /dev/null +++ b/ultralytics/models/sam2/modules/sam2_blocks.py @@ -0,0 +1,715 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import copy +import math +from functools import partial +from typing import Optional, Tuple, Type, Union + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ultralytics.models.sam.modules.transformer import ( + Attention, +) +from ultralytics.models.sam.modules.transformer import ( + TwoWayAttentionBlock as SAMTwoWayAttentionBlock, +) +from ultralytics.models.sam.modules.transformer import ( + TwoWayTransformer as SAMTwoWayTransformer, +) +from ultralytics.nn.modules import MLP, LayerNorm2d + +from .utils import apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition + + +class DropPath(nn.Module): + """Implements stochastic depth regularization for neural networks during training.""" + + def __init__(self, drop_prob=0.0, scale_by_keep=True): + """Initialize DropPath module with specified drop probability and scaling option.""" + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Applies stochastic depth to input tensor during training, with optional scaling.""" + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class MaskDownSampler(nn.Module): + """Downsamples and embeds masks using convolutional layers and layer normalization for efficient processing.""" + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + """Initializes a mask downsampler module for progressive downsampling and channel expansion.""" + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + """ + ConvNeXt Block for efficient feature extraction in convolutional neural networks. + + This block implements a modified version of the ConvNeXt architecture, offering two equivalent + implementations for improved performance and flexibility. + + Attributes: + dwconv (nn.Conv2d): Depthwise convolution layer. + norm (LayerNorm2d): Layer normalization applied to channels. + pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. + act (nn.GELU): GELU activation function. + pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. + gamma (nn.Parameter | None): Learnable scale parameter for layer scaling. + drop_path (nn.Module): DropPath layer for stochastic depth regularization. + + Methods: + forward: Processes the input tensor through the ConvNeXt block. + + Examples: + >>> import torch + >>> x = torch.randn(1, 64, 56, 56) + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 56, 56]) + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + """ + Initialize a ConvNeXt Block. + + This block implements a ConvNeXt architecture with optional depthwise convolution, layer normalization, + pointwise convolutions, and GELU activation. + + Args: + dim (int): Number of input channels. + kernel_size (int): Size of the convolutional kernel. Default is 7. + padding (int): Padding size for the convolution. Default is 3. + drop_path (float): Stochastic depth rate. Default is 0.0. + layer_scale_init_value (float): Initial value for Layer Scale. Default is 1e-6. + use_dwconv (bool): Whether to use depthwise convolution. Default is True. + + Attributes: + dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer. + norm (LayerNorm2d): Layer normalization applied to the output of dwconv. + pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. + act (nn.GELU): GELU activation function. + pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. + gamma (nn.Parameter | None): Learnable scale parameter for the residual path. + + Examples: + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 32, 32]) + """ + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection.""" + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + """ + A module for fusing features through multiple layers of a neural network. + + This class applies a series of identical layers to an input tensor, optionally projecting the input first. + + Attributes: + proj (nn.Module): An optional input projection layer. Identity if no projection is needed. + layers (nn.ModuleList): A list of identical layers to be applied sequentially. + + Methods: + forward: Applies the fuser to an input tensor. + + Examples: + >>> layer = CXBlock(dim=256) + >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True) + >>> x = torch.randn(1, 256, 32, 32) + >>> output = fuser(x) + >>> print(output.shape) + torch.Size([1, 256, 32, 32]) + """ + + def __init__(self, layer, num_layers, dim=None, input_projection=False): + """ + Initializes the Fuser module. + + This module creates a sequence of identical layers and optionally applies an input projection. + + Args: + layer (nn.Module): The layer to be replicated in the fuser. + num_layers (int): The number of times to replicate the layer. + dim (int | None): The dimension for input projection, if used. + input_projection (bool): Whether to use input projection. + + Attributes: + proj (nn.Module): The input projection layer, or nn.Identity if not used. + layers (nn.ModuleList): A list of replicated layers. + + Examples: + >>> layer = nn.Linear(64, 64) + >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True) + >>> input_tensor = torch.randn(1, 64) + >>> output = fuser(input_tensor) + """ + super().__init__() + self.proj = nn.Identity() + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + """Applies a series of layers to the input tensor, optionally projecting it first.""" + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class TwoWayAttentionBlock(SAMTwoWayAttentionBlock): + """ + A two-way attention block for performing self-attention and cross-attention in both directions. + + This block extends the SAMTwoWayAttentionBlock and consists of four main components: self-attention on + sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and + cross-attention from dense to sparse inputs. + + Attributes: + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after the first attention block. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization after the second attention block. + mlp (MLP): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after the MLP block. + norm4 (nn.LayerNorm): Layer normalization after the third attention block. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer. + + Methods: + forward: Processes input through the attention blocks and MLP. + + Examples: + >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8) + >>> sparse_input = torch.randn(1, 100, 256) + >>> dense_input = torch.randn(1, 256, 16, 16) + >>> sparse_output, dense_output = block(sparse_input, dense_input) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + Initializes a TwoWayAttentionBlock for performing self-attention and cross-attention in two directions. + + This block consists of four main layers: self-attention on sparse inputs, cross-attention of sparse inputs + to dense inputs, an MLP block on sparse inputs, and cross-attention of dense inputs to sparse inputs. + + Args: + embedding_dim (int): The channel dimension of the embeddings. + num_heads (int): The number of heads in the attention layers. + mlp_dim (int): The hidden dimension of the MLP block. + activation (Type[nn.Module]): The activation function of the MLP block. + attention_downsample_rate (int): The downsample rate for attention computations. + skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + + Attributes: + self_attn (Attention): The self-attention layer for the queries. + norm1 (nn.LayerNorm): Layer normalization following the first attention block. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization following the second attention block. + mlp (MLP): MLP block that transforms the query embeddings. + norm3 (nn.LayerNorm): Layer normalization following the MLP block. + norm4 (nn.LayerNorm): Layer normalization following the third attention block. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + + Examples: + >>> block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> sparse_inputs = torch.randn(1, 100, 256) + >>> dense_inputs = torch.randn(1, 256, 32, 32) + >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs) + """ + super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe) + self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation) + + +class TwoWayTransformer(SAMTwoWayTransformer): + """ + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class implements a specialized transformer decoder that attends to an input image using queries with + supplied positional embeddings. It is particularly useful for tasks like object detection, image + segmentation, and point cloud processing. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Methods: + forward: Processes input image embeddings and query embeddings through the transformer. + + Examples: + >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 64, 64) + >>> query_embedding = torch.randn(1, 100, 256) + >>> output = transformer(image_embedding, query_embedding) + """ + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + Initializes a TwoWayTransformer instance. + + This transformer decoder attends to an input image using queries with supplied positional embeddings. + It is designed for tasks like object detection, image segmentation, and point cloud processing. + + Args: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for the input embeddings. + num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Channel dimension internal to the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention computations. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for the input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers comprising the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to the final queries. + + Examples: + >>> transformer = TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> transformer + TwoWayTransformer( + (layers): ModuleList( + (0-4): 5 x TwoWayAttentionBlock(...) + ) + (final_attn_token_to_image): Attention(...) + (norm_final_attn): LayerNorm(...) + ) + """ + super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate) + self.layers = nn.ModuleList() + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + +class RoPEAttention(Attention): + """Implements rotary position encoding for attention mechanisms in transformer architectures.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + """Initializes RoPEAttention with rotary position encoding for attention mechanisms.""" + super().__init__(*args, **kwargs) + + self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor: + """Applies rotary position encoding and computes attention between query, key, and value tensors.""" + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + """Applies pooling and optional normalization to a tensor, handling permutations for spatial operations.""" + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + """Implements multi-scale self-attention with optional query pooling for efficient feature extraction.""" + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + """Initializes a multi-scale attention module with configurable query pooling and linear projections.""" + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies multi-scale attention to input tensor, optionally downsampling query features.""" + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + """Multiscale attention block with window partitioning and query pooling for efficient vision transformers.""" + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + """Initializes a multi-scale attention block with optional window partitioning and downsampling.""" + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + act=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies multi-scale attention and MLP processing to input tensor, with optional windowing.""" + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PositionEmbeddingSine(nn.Module): + """Generates sinusoidal positional embeddings for 2D inputs like images.""" + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + """Initializes sinusoidal position embeddings for 2D image inputs.""" + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + """Encodes 2D positions using sine and cosine functions for positional embeddings.""" + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + """Encodes box coordinates and dimensions into positional embeddings for object detection tasks.""" + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + """Encodes 2D point coordinates with sinusoidal positional embeddings and appends labels.""" + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + """Generate sinusoidal position embeddings for 2D inputs.""" + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos diff --git a/ultralytics/models/sam2/modules/utils.py b/ultralytics/models/sam2/modules/utils.py new file mode 100644 index 00000000..b09dd9b2 --- /dev/null +++ b/ultralytics/models/sam2/modules/utils.py @@ -0,0 +1,191 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import torch +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Selects the closest conditioning frames to a given frame index. + + Args: + frame_idx (int): Current frame index. + cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices. + max_cond_frame_num (int): Maximum number of conditioning frames to select. + + Returns: + (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries: + - selected_outputs: Selected items from cond_frame_outputs. + - unselected_outputs: Items not selected from cond_frame_outputs. + + Examples: + >>> frame_idx = 5 + >>> cond_frame_outputs = {1: 'a', 3: 'b', 7: 'c', 9: 'd'} + >>> max_cond_frame_num = 2 + >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num) + >>> print(selected) + {3: 'b', 7: 'c'} + >>> print(unselected) + {1: 'a', 9: 'd'} + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs} + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """Generates 1D sinusoidal positional embeddings for given positions and dimensions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def init_t_xy(end_x: int, end_y: int): + """Initializes 1D and 2D coordinate tensors for a grid of size end_x by end_y.""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + """Computes axial complex exponential positional encodings for 2D spatial positions.""" + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """Reshapes frequency tensor for broadcasting, ensuring compatibility with input tensor dimensions.""" + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + """Applies rotary positional encoding to query and key tensors using complex-valued frequency components.""" + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +def window_partition(x, window_size): + """ + Partitions input tensor into non-overlapping windows with padding if needed. + + Args: + x (torch.Tensor): Input tensor with shape (B, H, W, C). + window_size (int): Size of each window. + + Returns: + (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing: + - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C). + - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition. + + Examples: + >>> x = torch.randn(1, 16, 16, 3) + >>> windows, (Hp, Wp) = window_partition(x, window_size=4) + >>> print(windows.shape, Hp, Wp) + torch.Size([16, 4, 4, 3]) 16 16 + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Unpartitions windowed sequences into original sequences and removes padding. + + This function reverses the windowing process, reconstructing the original input from windowed segments + and removing any padding that was added during the windowing process. + + Args: + windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size, + window_size, C), where B is the batch size, num_windows is the number of windows, window_size is + the size of each window, and C is the number of channels. + window_size (int): Size of each window. + pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing. + hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing. + + Returns: + (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W + are the original height and width, and C is the number of channels. + + Examples: + >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels + >>> pad_hw = (16, 16) # Padded height and width + >>> hw = (15, 14) # Original height and width + >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw) + >>> print(x.shape) + torch.Size([1, 15, 14, 64]) + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x diff --git a/ultralytics/models/sam2/predict.py b/ultralytics/models/sam2/predict.py new file mode 100644 index 00000000..ca9438a2 --- /dev/null +++ b/ultralytics/models/sam2/predict.py @@ -0,0 +1,182 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import torch + +from ..sam.predict import Predictor +from .build import build_sam2 + + +class SAM2Predictor(Predictor): + """ + A predictor class for the Segment Anything Model 2 (SAM2), extending the base Predictor class. + + This class provides an interface for model inference tailored to image segmentation tasks, leveraging SAM2's + advanced architecture and promptable segmentation capabilities. It facilitates flexible and real-time mask + generation, working with various types of prompts such as bounding boxes, points, and low-resolution masks. + + Attributes: + cfg (Dict): Configuration dictionary specifying model and task-related parameters. + overrides (Dict): Dictionary containing values that override the default configuration. + _callbacks (Dict): Dictionary of user-defined callback functions to augment behavior. + args (namespace): Namespace to hold command-line arguments or other operational variables. + im (torch.Tensor): Preprocessed input image tensor. + features (torch.Tensor): Extracted image features used for inference. + prompts (Dict): Collection of various prompt types, such as bounding boxes and points. + segment_all (bool): Flag to control whether to segment all objects in the image or only specified ones. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + _bb_feat_sizes (List[Tuple[int, int]]): List of feature sizes for different backbone levels. + + Methods: + get_model: Builds and returns the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts image features from the SAM2 image encoder. + + Examples: + >>> predictor = SAM2Predictor(model='sam2_l.pt') + >>> predictor.set_image('path/to/image.jpg') + >>> masks, scores = predictor.prompt_inference(im=predictor.im, points=[[500, 375]], labels=[1]) + >>> print(f"Generated {len(masks)} mask(s) with scores: {scores}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model (SAM) for image segmentation tasks.""" + return build_sam2(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + img_idx (int): Index of the image in the batch to process. + + Returns: + (tuple): Tuple containing: + - np.ndarray: Output masks with shape (C, H, W), where C is the number of generated masks. + - np.ndarray: Quality scores for each mask, with length C. + - np.ndarray: Low-resolution logits with shape (C, 256, 256) for subsequent inference. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(image, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + src_shape, dst_shape = self.batch[1][0].shape[:2], im.shape[2:] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = torch.ones(points.shape[0]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + points *= r + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + + points = (points, labels) if points is not None else None + # TODO: Embed prompts + # if bboxes is not None: + # box_coords = bboxes.reshape(-1, 2, 2) + # box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=bboxes.device) + # box_labels = box_labels.repeat(bboxes.size(0), 1) + # # we merge "boxes" and "points" into a single "concat_points" input (where + # # boxes are added at the beginning) to sam_prompt_encoder + # if concat_points is not None: + # concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + # concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + # concat_points = (concat_coords, concat_labels) + # else: + # concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=bboxes, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference. + + This function sets up the model if not already initialized, configures the data source to the specified image, + and preprocesses the image for feature extraction. Only one image can be set at a time. + + Args: + image (str | np.ndarray): Image file path as a string, or a numpy array image read by cv2. + + Raises: + AssertionError: If more than one image is set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks.""" + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index 062c6094..4c170aa3 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -174,18 +174,20 @@ class MLPBlock(nn.Module): class MLP(nn.Module): """Implements a simple multi-layer perceptron (also called FFN).""" - def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False): """Initialize the MLP with specified input, hidden, output dimensions and number of layers.""" super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid = sigmoid + self.act = act() def forward(self, x): """Forward pass for the entire MLP.""" for i, layer in enumerate(self.layers): - x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) - return x + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.sigmoid() if self.sigmoid else x class LayerNorm2d(nn.Module):