ultralytics 8.2.69 FastSAM prompt inference refactor (#14724)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
82c4bdad10
commit
9532ad7cae
11 changed files with 187 additions and 427 deletions
|
|
@ -8,11 +8,11 @@ keywords: IBM Watsonx, IBM Watsonx AI, What is Watson?, IBM Watson Integration,
|
|||
|
||||
Nowadays, scalable [computer vision solutions](../guides/steps-of-a-cv-project.md) are becoming more common and transforming the way we handle visual data. A great example is IBM Watsonx, an advanced AI and data platform that simplifies the development, deployment, and management of AI models. It offers a complete suite for the entire AI lifecycle and seamless integration with IBM Cloud services.
|
||||
|
||||
You can train [Ultralytics YOLOv8 models](https://github.com/ultralytics/ultralytics) using IBM Watsonx. It's a good option for enterprises interested in efficient [model training](../modes/train.md), fine-tuning for specific tasks, and improving [model performance](../guides/model-evaluation-insights.md) with robust tools and a user-friendly setup. In this guide, we'll walk you through the process of training YOLOv8 with IBM Watsonx, covering everything from setting up your environment to evaluating your trained models. Let’s get started!
|
||||
You can train [Ultralytics YOLOv8 models](https://github.com/ultralytics/ultralytics) using IBM Watsonx. It's a good option for enterprises interested in efficient [model training](../modes/train.md), fine-tuning for specific tasks, and improving [model performance](../guides/model-evaluation-insights.md) with robust tools and a user-friendly setup. In this guide, we'll walk you through the process of training YOLOv8 with IBM Watsonx, covering everything from setting up your environment to evaluating your trained models. Let's get started!
|
||||
|
||||
## What is IBM Watsonx?
|
||||
|
||||
[Watsonx](https://www.ibm.com/watsonx) is IBM's cloud-based platform designed for commercial generative AI and scientific data. IBM Watsonx’s three components - watsonx.ai, watsonx.data, and watsonx.governance - come together to create an end-to-end, trustworthy AI platform that can accelerate AI projects aimed at solving business problems. It provides powerful tools for building, training, and [deploying machine learning models](../guides/model-deployment-options.md) and makes it easy to connect with various data sources.
|
||||
[Watsonx](https://www.ibm.com/watsonx) is IBM's cloud-based platform designed for commercial generative AI and scientific data. IBM Watsonx's three components - watsonx.ai, watsonx.data, and watsonx.governance - come together to create an end-to-end, trustworthy AI platform that can accelerate AI projects aimed at solving business problems. It provides powerful tools for building, training, and [deploying machine learning models](../guides/model-deployment-options.md) and makes it easy to connect with various data sources.
|
||||
|
||||
<p align="center">
|
||||
<img width="800" src="https://cdn.stackoverflow.co/images/jo7n4k8s/production/48b67e6aec41f89031a3426cbd1f78322e6776cb-8800x4950.jpg?auto=format" alt="Overview of IBM Watsonx">
|
||||
|
|
@ -22,7 +22,7 @@ Its user-friendly interface and collaborative capabilities streamline the develo
|
|||
|
||||
## Key Features of IBM Watsonx
|
||||
|
||||
IBM Watsonx is made of three main components: watsonx.ai, watsonx.data, and watsonx.governance. Each component offers features that cater to different aspects of AI and data management. Let’s take a closer look at them.
|
||||
IBM Watsonx is made of three main components: watsonx.ai, watsonx.data, and watsonx.governance. Each component offers features that cater to different aspects of AI and data management. Let's take a closer look at them.
|
||||
|
||||
### [Watsonx.ai](https://www.ibm.com/products/watsonx-ai)
|
||||
|
||||
|
|
@ -42,11 +42,11 @@ You can use IBM Watsonx to accelerate your YOLOv8 model training workflow.
|
|||
|
||||
### Prerequisites
|
||||
|
||||
You need an [IBM Cloud account](https://cloud.ibm.com/registration) to create a [watsonx.ai](https://www.ibm.com/products/watsonx-ai) project, and you’ll also need a [Kaggle](./kaggle.md) account to load the data set.
|
||||
You need an [IBM Cloud account](https://cloud.ibm.com/registration) to create a [watsonx.ai](https://www.ibm.com/products/watsonx-ai) project, and you'll also need a [Kaggle](./kaggle.md) account to load the data set.
|
||||
|
||||
### Step 1: Set Up Your Environment
|
||||
|
||||
First, you’ll need to set up an IBM account to use a Jupyter Notebook. Log in to [watsonx.ai](https://eu-de.dataplatform.cloud.ibm.com/registration/stepone?preselect_region=true) using your IBM Cloud account.
|
||||
First, you'll need to set up an IBM account to use a Jupyter Notebook. Log in to [watsonx.ai](https://eu-de.dataplatform.cloud.ibm.com/registration/stepone?preselect_region=true) using your IBM Cloud account.
|
||||
|
||||
Then, create a [watsonx.ai project](https://www.ibm.com/docs/en/watsonx/saas?topic=projects-creating-project), and a [Jupyter Notebook](https://www.ibm.com/docs/en/watsonx/saas?topic=editor-creating-managing-notebooks).
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ Then, you can import the needed packages.
|
|||
|
||||
For this tutorial, we will use a [marine litter dataset](https://www.kaggle.com/datasets/atiqishrak/trash-dataset-icra19) available on Kaggle. With this dataset, we will custom-train a YOLOv8 model to detect and classify litter and biological objects in underwater images.
|
||||
|
||||
We can load the dataset directly into the notebook using the Kaggle API. First, create a free Kaggle account. Once you have created an account, you’ll need to generate an API key. Directions for generating your key can be found in the [Kaggle API documentation](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md) under the section "API credentials".
|
||||
We can load the dataset directly into the notebook using the Kaggle API. First, create a free Kaggle account. Once you have created an account, you'll need to generate an API key. Directions for generating your key can be found in the [Kaggle API documentation](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md) under the section "API credentials".
|
||||
|
||||
Copy and paste your Kaggle username and API key into the following code. Then run the code to install the API and load the dataset into Watsonx.
|
||||
|
||||
|
|
@ -248,7 +248,7 @@ Run the following command-line code to fine tune a pretrained default YOLOv8 mod
|
|||
!yolo task=detect mode=train data={work_dir}/trash_ICRA19/config.yaml model=yolov8s.pt epochs=2 batch=32 lr0=.04 plots=True
|
||||
```
|
||||
|
||||
Here’s a closer look at the parameters in the model training command:
|
||||
Here's a closer look at the parameters in the model training command:
|
||||
|
||||
- **task**: It specifies the computer vision task for which you are using the specified YOLO model and data set.
|
||||
- **mode**: Denotes the purpose for which you are loading the specified model and data. Since we are training a model, it is set to "train." Later, when we test our model's performance, we will set it to "predict."
|
||||
|
|
@ -257,7 +257,7 @@ Here’s a closer look at the parameters in the model training command:
|
|||
- **lr0**: Specifies the model's initial learning rate.
|
||||
- **plots**: Directs YOLO to generate and save plots of our model's training and evaluation metrics.
|
||||
|
||||
For a detailed understanding of the model training process and best practices, refer to the [YOLOv8 Model Training guide](../modes/train.md). This guide will help you get the most out of your experiments and ensure you’re using YOLOv8 effectively.
|
||||
For a detailed understanding of the model training process and best practices, refer to the [YOLOv8 Model Training guide](../modes/train.md). This guide will help you get the most out of your experiments and ensure you're using YOLOv8 effectively.
|
||||
|
||||
### Step 6: Test the Model
|
||||
|
||||
|
|
@ -312,7 +312,7 @@ Unlike precision, recall moves in the opposite direction, showing greater recall
|
|||
|
||||
### Step 8: Calculating Intersection Over Union
|
||||
|
||||
You can measure the prediction accuracy by calculating the IoU between a predicted bounding box and a ground truth bounding box for the same object. Check out [IBM’s tutorial on training YOLOv8](https://developer.ibm.com/tutorials/awb-train-yolo-object-detection-model-in-python/) for more details.
|
||||
You can measure the prediction accuracy by calculating the IoU between a predicted bounding box and a ground truth bounding box for the same object. Check out [IBM's tutorial on training YOLOv8](https://developer.ibm.com/tutorials/awb-train-yolo-object-detection-model-in-python/) for more details.
|
||||
|
||||
## Summary
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ To perform object detection on an image, use the `predict` method as shown below
|
|||
|
||||
```python
|
||||
from ultralytics import FastSAM
|
||||
from ultralytics.models.fastsam import FastSAMPrompt
|
||||
|
||||
# Define an inference source
|
||||
source = "path/to/bus.jpg"
|
||||
|
|
@ -77,23 +76,17 @@ To perform object detection on an image, use the `predict` method as shown below
|
|||
# Run inference on an image
|
||||
everything_results = model(source, device="cpu", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
|
||||
|
||||
# Prepare a Prompt Process object
|
||||
prompt_process = FastSAMPrompt(source, everything_results, device="cpu")
|
||||
# Run inference with bboxes prompt
|
||||
results = model(source, bboxes=[439, 437, 524, 709])
|
||||
|
||||
# Everything prompt
|
||||
results = prompt_process.everything_prompt()
|
||||
# Run inference with points prompt
|
||||
results = model(source, points=[[200, 200]], labels=[1])
|
||||
|
||||
# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
|
||||
results = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
|
||||
# Run inference with texts prompt
|
||||
results = model(source, texts="a photo of a dog")
|
||||
|
||||
# Text prompt
|
||||
results = prompt_process.text_prompt(text="a photo of a dog")
|
||||
|
||||
# Point prompt
|
||||
# points default [[0,0]] [[x1,y1],[x2,y2]]
|
||||
# point_label default [0] [1,0] 0:background, 1:foreground
|
||||
results = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
|
||||
prompt_process.plot(annotations=results, output="./")
|
||||
# Run inference with bboxes and points and texts prompt at the same time
|
||||
results = model(source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog")
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
|
|
@ -105,6 +98,28 @@ To perform object detection on an image, use the `predict` method as shown below
|
|||
|
||||
This snippet demonstrates the simplicity of loading a pre-trained model and running a prediction on an image.
|
||||
|
||||
!!! Example "FastSAMPredictor example"
|
||||
|
||||
This way you can run inference on image and get all the segment `results` once and run prompts inference multiple times without running inference multiple times.
|
||||
|
||||
=== "Prompt inference"
|
||||
|
||||
```python
|
||||
from ultralytics.models.fastsam import FastSAMPredictor
|
||||
|
||||
# Create FastSAMPredictor
|
||||
overrides = dict(conf=0.25, task="segment", mode="predict", model="FastSAM-s.pt", save=False, imgsz=1024)
|
||||
predictor = FastSAMPredictor(overrides=overrides)
|
||||
|
||||
# Segment everything
|
||||
everything_results = predictor("ultralytics/assets/bus.jpg")
|
||||
|
||||
# Prompt inference
|
||||
bbox_results = predictor.prompt(everything_results, bboxes=[[200, 200, 300, 300]])
|
||||
point_results = predictor.prompt(everything_results, points=[200, 200])
|
||||
text_results = predictor.prompt(everything_results, texts="a photo of a dog")
|
||||
```
|
||||
|
||||
!!! Note
|
||||
|
||||
All the returned `results` in above examples are [Results](../modes/predict.md#working-with-results) object which allows access predicted masks and source image easily.
|
||||
|
|
@ -270,7 +285,6 @@ To use FastSAM for inference in Python, you can follow the example below:
|
|||
|
||||
```python
|
||||
from ultralytics import FastSAM
|
||||
from ultralytics.models.fastsam import FastSAMPrompt
|
||||
|
||||
# Define an inference source
|
||||
source = "path/to/bus.jpg"
|
||||
|
|
@ -281,21 +295,17 @@ model = FastSAM("FastSAM-s.pt") # or FastSAM-x.pt
|
|||
# Run inference on an image
|
||||
everything_results = model(source, device="cpu", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
|
||||
|
||||
# Prepare a Prompt Process object
|
||||
prompt_process = FastSAMPrompt(source, everything_results, device="cpu")
|
||||
# Run inference with bboxes prompt
|
||||
results = model(source, bboxes=[439, 437, 524, 709])
|
||||
|
||||
# Everything prompt
|
||||
ann = prompt_process.everything_prompt()
|
||||
# Run inference with points prompt
|
||||
results = model(source, points=[[200, 200]], labels=[1])
|
||||
|
||||
# Bounding box prompt
|
||||
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
|
||||
# Run inference with texts prompt
|
||||
results = model(source, texts="a photo of a dog")
|
||||
|
||||
# Text prompt
|
||||
ann = prompt_process.text_prompt(text="a photo of a dog")
|
||||
|
||||
# Point prompt
|
||||
ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
|
||||
prompt_process.plot(annotations=ann, output="./")
|
||||
# Run inference with bboxes and points and texts prompt at the same time
|
||||
results = model(source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog")
|
||||
```
|
||||
|
||||
For more details on inference methods, check the [Predict Usage](#predict-usage) section of the documentation.
|
||||
|
|
|
|||
|
|
@ -1,16 +0,0 @@
|
|||
---
|
||||
description: Explore the FastSAM prompt module for image annotation and visualization in Ultralytics, detailed with class methods and attributes.
|
||||
keywords: Ultralytics, FastSAM, image annotation, image visualization, FastSAMPrompt, YOLO, python script
|
||||
---
|
||||
|
||||
# Reference for `ultralytics/models/fastsam/prompt.py`
|
||||
|
||||
!!! Note
|
||||
|
||||
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/prompt.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/models/fastsam/prompt.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/models/fastsam/prompt.py) 🛠️. Thank you 🙏!
|
||||
|
||||
<br>
|
||||
|
||||
## ::: ultralytics.models.fastsam.prompt.FastSAMPrompt
|
||||
|
||||
<br><br>
|
||||
|
|
@ -486,7 +486,6 @@ nav:
|
|||
- fastsam:
|
||||
- model: reference/models/fastsam/model.md
|
||||
- predict: reference/models/fastsam/predict.md
|
||||
- prompt: reference/models/fastsam/prompt.md
|
||||
- utils: reference/models/fastsam/utils.md
|
||||
- val: reference/models/fastsam/val.md
|
||||
- nas:
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ def test_fastsam(task="segment", model=WEIGHTS_DIR / "FastSAM-s.pt", data="coco8
|
|||
run(f"yolo segment predict model={model} source={source} imgsz=32 save save_crop save_txt")
|
||||
|
||||
from ultralytics import FastSAM
|
||||
from ultralytics.models.fastsam import FastSAMPrompt
|
||||
from ultralytics.models.sam import Predictor
|
||||
|
||||
# Create a FastSAM model
|
||||
|
|
@ -81,21 +80,10 @@ def test_fastsam(task="segment", model=WEIGHTS_DIR / "FastSAM-s.pt", data="coco8
|
|||
# Remove small regions
|
||||
new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20)
|
||||
|
||||
# Everything prompt
|
||||
prompt_process = FastSAMPrompt(s, everything_results, device="cpu")
|
||||
ann = prompt_process.everything_prompt()
|
||||
|
||||
# Bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
|
||||
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])
|
||||
|
||||
# Text prompt
|
||||
ann = prompt_process.text_prompt(text="a photo of a dog")
|
||||
|
||||
# Point prompt
|
||||
# Points default [[0,0]] [[x1,y1],[x2,y2]]
|
||||
# Point_label default [0] [1,0] 0:background, 1:foreground
|
||||
ann = prompt_process.point_prompt(points=[[200, 200]], pointlabel=[1])
|
||||
prompt_process.plot(annotations=ann, output="./")
|
||||
# Run inference with bboxes and points and texts prompt at the same time
|
||||
results = sam_model(
|
||||
source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog"
|
||||
)
|
||||
|
||||
|
||||
def test_mobilesam():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.68"
|
||||
__version__ = "8.2.69"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
from .model import FastSAM
|
||||
from .predict import FastSAMPredictor
|
||||
from .prompt import FastSAMPrompt
|
||||
from .val import FastSAMValidator
|
||||
|
||||
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMPrompt", "FastSAMValidator"
|
||||
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
|
||||
|
|
|
|||
|
|
@ -28,6 +28,24 @@ class FastSAM(Model):
|
|||
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs):
|
||||
"""
|
||||
Performs segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source (str): Path to the image or video file, or a PIL.Image object, or a numpy.ndarray object.
|
||||
stream (bool, optional): If True, enables real-time streaming. Defaults to False.
|
||||
bboxes (list, optional): List of bounding box coordinates for prompted segmentation. Defaults to None.
|
||||
points (list, optional): List of points for prompted segmentation. Defaults to None.
|
||||
labels (list, optional): List of labels for prompted segmentation. Defaults to None.
|
||||
texts (list, optional): List of texts for prompted segmentation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
"""
|
||||
prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
|
||||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.models.yolo.segment import SegmentationPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, checks
|
||||
from ultralytics.utils.metrics import box_iou
|
||||
from ultralytics.utils.ops import scale_masks
|
||||
|
||||
from .utils import adjust_bboxes_to_image_border
|
||||
|
||||
|
|
@ -17,8 +20,16 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
class segmentation.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.prompts = {}
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Applies box postprocess for FastSAM predictions."""
|
||||
bboxes = self.prompts.pop("bboxes", None)
|
||||
points = self.prompts.pop("points", None)
|
||||
labels = self.prompts.pop("labels", None)
|
||||
texts = self.prompts.pop("texts", None)
|
||||
results = super().postprocess(preds, img, orig_imgs)
|
||||
for result in results:
|
||||
full_box = torch.tensor(
|
||||
|
|
@ -28,4 +39,107 @@ class FastSAMPredictor(SegmentationPredictor):
|
|||
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
|
||||
if idx.numel() != 0:
|
||||
result.boxes.xyxy[idx] = full_box
|
||||
|
||||
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
|
||||
|
||||
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
|
||||
"""
|
||||
Internal function for image segmentation inference based on cues like bounding boxes, points, and masks.
|
||||
Leverages SAM's specialized architecture for prompt-based, real-time segmentation.
|
||||
|
||||
Args:
|
||||
results (Results | List[Results]): The original inference results from FastSAM models without any prompts.
|
||||
bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format.
|
||||
points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels.
|
||||
labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
|
||||
texts (str | List[str], optional): Textual prompts, a list contains string objects.
|
||||
|
||||
Returns:
|
||||
(List[Results]): The output results determined by prompts.
|
||||
"""
|
||||
if bboxes is None and points is None and texts is None:
|
||||
return results
|
||||
prompt_results = []
|
||||
if not isinstance(results, list):
|
||||
results = [results]
|
||||
for result in results:
|
||||
masks = result.masks.data
|
||||
if masks.shape[1:] != result.orig_shape:
|
||||
masks = scale_masks(masks[None], result.orig_shape)[0]
|
||||
# bboxes prompt
|
||||
idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
|
||||
if bboxes is not None:
|
||||
bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
|
||||
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
|
||||
bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
|
||||
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
|
||||
full_mask_areas = torch.sum(masks, dim=(1, 2))
|
||||
|
||||
union = bbox_areas[:, None] + full_mask_areas - mask_areas
|
||||
idx[torch.argmax(mask_areas / union, dim=1)] = True
|
||||
if points is not None:
|
||||
points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
|
||||
points = points[None] if points.ndim == 1 else points
|
||||
if labels is None:
|
||||
labels = torch.ones(points.shape[0])
|
||||
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
|
||||
assert len(labels) == len(
|
||||
points
|
||||
), f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}"
|
||||
point_idx = (
|
||||
torch.ones(len(result), dtype=torch.bool, device=self.device)
|
||||
if labels.sum() == 0 # all negative points
|
||||
else torch.zeros(len(result), dtype=torch.bool, device=self.device)
|
||||
)
|
||||
for p, l in zip(points, labels):
|
||||
point_idx[torch.nonzero(masks[:, p[1], p[0]], as_tuple=True)[0]] = True if l else False
|
||||
idx |= point_idx
|
||||
if texts is not None:
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
crop_ims, filter_idx = [], []
|
||||
for i, b in enumerate(result.boxes.xyxy.tolist()):
|
||||
x1, y1, x2, y2 = [int(x) for x in b]
|
||||
if masks[i].sum() <= 100:
|
||||
filter_idx.append(i)
|
||||
continue
|
||||
crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
|
||||
similarity = self._clip_inference(crop_ims, texts)
|
||||
text_idx = torch.argmax(similarity, dim=-1) # (M, )
|
||||
if len(filter_idx):
|
||||
text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
|
||||
idx[text_idx] = True
|
||||
|
||||
prompt_results.append(result[idx])
|
||||
|
||||
return prompt_results
|
||||
|
||||
def _clip_inference(self, images, texts):
|
||||
"""
|
||||
CLIP Inference process.
|
||||
|
||||
Args:
|
||||
images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order.
|
||||
texts (List[str]): A list of prompt texts and each of them should be string object.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The similarity between given images and texts.
|
||||
"""
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||
import clip
|
||||
if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
|
||||
self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
|
||||
images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
|
||||
tokenized_text = clip.tokenize(texts).to(self.device)
|
||||
image_features = self.clip_model.encode_image(images)
|
||||
text_features = self.clip_model.encode_text(tokenized_text)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)
|
||||
return (image_features * text_features[:, None]).sum(-1) # (M, N)
|
||||
|
||||
def set_prompts(self, prompts):
|
||||
"""Set prompts in advance."""
|
||||
self.prompts = prompts
|
||||
|
|
|
|||
|
|
@ -1,352 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor
|
||||
|
||||
from ultralytics.utils import TQDM, checks
|
||||
|
||||
|
||||
class FastSAMPrompt:
|
||||
"""
|
||||
Fast Segment Anything Model class for image annotation and visualization.
|
||||
|
||||
Attributes:
|
||||
device (str): Computing device ('cuda' or 'cpu').
|
||||
results: Object detection or segmentation results.
|
||||
source: Source image or image path.
|
||||
clip: CLIP model for linear assignment.
|
||||
"""
|
||||
|
||||
def __init__(self, source, results, device="cuda") -> None:
|
||||
"""Initializes FastSAMPrompt with given source, results and device, and assigns clip for linear assignment."""
|
||||
if isinstance(source, (str, Path)) and os.path.isdir(source):
|
||||
raise ValueError("FastSAM only accepts image paths and PIL Image sources, not directories.")
|
||||
self.device = device
|
||||
self.results = results
|
||||
self.source = source
|
||||
|
||||
# Import and assign clip
|
||||
try:
|
||||
import clip
|
||||
except ImportError:
|
||||
checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
|
||||
import clip
|
||||
self.clip = clip
|
||||
|
||||
@staticmethod
|
||||
def _segment_image(image, bbox):
|
||||
"""Segments the given image according to the provided bounding box coordinates."""
|
||||
image_array = np.array(image)
|
||||
segmented_image_array = np.zeros_like(image_array)
|
||||
x1, y1, x2, y2 = bbox
|
||||
segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
|
||||
segmented_image = Image.fromarray(segmented_image_array)
|
||||
black_image = Image.new("RGB", image.size, (255, 255, 255))
|
||||
# transparency_mask = np.zeros_like((), dtype=np.uint8)
|
||||
transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8)
|
||||
transparency_mask[y1:y2, x1:x2] = 255
|
||||
transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
|
||||
black_image.paste(segmented_image, mask=transparency_mask_image)
|
||||
return black_image
|
||||
|
||||
@staticmethod
|
||||
def _format_results(result, filter=0):
|
||||
"""Formats detection results into list of annotations each containing ID, segmentation, bounding box, score and
|
||||
area.
|
||||
"""
|
||||
annotations = []
|
||||
n = len(result.masks.data) if result.masks is not None else 0
|
||||
for i in range(n):
|
||||
mask = result.masks.data[i] == 1.0
|
||||
if torch.sum(mask) >= filter:
|
||||
annotation = {
|
||||
"id": i,
|
||||
"segmentation": mask.cpu().numpy(),
|
||||
"bbox": result.boxes.data[i],
|
||||
"score": result.boxes.conf[i],
|
||||
}
|
||||
annotation["area"] = annotation["segmentation"].sum()
|
||||
annotations.append(annotation)
|
||||
return annotations
|
||||
|
||||
@staticmethod
|
||||
def _get_bbox_from_mask(mask):
|
||||
"""Applies morphological transformations to the mask, displays it, and if with_contours is True, draws
|
||||
contours.
|
||||
"""
|
||||
mask = mask.astype(np.uint8)
|
||||
contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
x1, y1, w, h = cv2.boundingRect(contours[0])
|
||||
x2, y2 = x1 + w, y1 + h
|
||||
if len(contours) > 1:
|
||||
for b in contours:
|
||||
x_t, y_t, w_t, h_t = cv2.boundingRect(b)
|
||||
x1 = min(x1, x_t)
|
||||
y1 = min(y1, y_t)
|
||||
x2 = max(x2, x_t + w_t)
|
||||
y2 = max(y2, y_t + h_t)
|
||||
return [x1, y1, x2, y2]
|
||||
|
||||
def plot(
|
||||
self,
|
||||
annotations,
|
||||
output,
|
||||
bbox=None,
|
||||
points=None,
|
||||
point_label=None,
|
||||
mask_random_color=True,
|
||||
better_quality=True,
|
||||
retina=False,
|
||||
with_contours=True,
|
||||
):
|
||||
"""
|
||||
Plots annotations, bounding boxes, and points on images and saves the output.
|
||||
|
||||
Args:
|
||||
annotations (list): Annotations to be plotted.
|
||||
output (str or Path): Output directory for saving the plots.
|
||||
bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
|
||||
points (list, optional): Points to be plotted. Defaults to None.
|
||||
point_label (list, optional): Labels for the points. Defaults to None.
|
||||
mask_random_color (bool, optional): Whether to use random color for masks. Defaults to True.
|
||||
better_quality (bool, optional): Whether to apply morphological transformations for better mask quality.
|
||||
Defaults to True.
|
||||
retina (bool, optional): Whether to use retina mask. Defaults to False.
|
||||
with_contours (bool, optional): Whether to plot contours. Defaults to True.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
pbar = TQDM(annotations, total=len(annotations))
|
||||
for ann in pbar:
|
||||
result_name = os.path.basename(ann.path)
|
||||
image = ann.orig_img[..., ::-1] # BGR to RGB
|
||||
original_h, original_w = ann.orig_shape
|
||||
# For macOS only
|
||||
# plt.switch_backend('TkAgg')
|
||||
plt.figure(figsize=(original_w / 100, original_h / 100))
|
||||
# Add subplot with no margin.
|
||||
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
plt.margins(0, 0)
|
||||
plt.gca().xaxis.set_major_locator(plt.NullLocator())
|
||||
plt.gca().yaxis.set_major_locator(plt.NullLocator())
|
||||
plt.imshow(image)
|
||||
|
||||
if ann.masks is not None:
|
||||
masks = ann.masks.data
|
||||
if better_quality:
|
||||
if isinstance(masks[0], torch.Tensor):
|
||||
masks = np.array(masks.cpu())
|
||||
for i, mask in enumerate(masks):
|
||||
mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
|
||||
masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
|
||||
|
||||
self.fast_show_mask(
|
||||
masks,
|
||||
plt.gca(),
|
||||
random_color=mask_random_color,
|
||||
bbox=bbox,
|
||||
points=points,
|
||||
pointlabel=point_label,
|
||||
retinamask=retina,
|
||||
target_height=original_h,
|
||||
target_width=original_w,
|
||||
)
|
||||
|
||||
if with_contours:
|
||||
contour_all = []
|
||||
temp = np.zeros((original_h, original_w, 1))
|
||||
for i, mask in enumerate(masks):
|
||||
mask = mask.astype(np.uint8)
|
||||
if not retina:
|
||||
mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contour_all.extend(iter(contours))
|
||||
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
|
||||
color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
|
||||
contour_mask = temp / 255 * color.reshape(1, 1, -1)
|
||||
plt.imshow(contour_mask)
|
||||
|
||||
# Save the figure
|
||||
save_path = Path(output) / result_name
|
||||
save_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
plt.axis("off")
|
||||
plt.savefig(save_path, bbox_inches="tight", pad_inches=0, transparent=True)
|
||||
plt.close()
|
||||
pbar.set_description(f"Saving {result_name} to {save_path}")
|
||||
|
||||
@staticmethod
|
||||
def fast_show_mask(
|
||||
annotation,
|
||||
ax,
|
||||
random_color=False,
|
||||
bbox=None,
|
||||
points=None,
|
||||
pointlabel=None,
|
||||
retinamask=True,
|
||||
target_height=960,
|
||||
target_width=960,
|
||||
):
|
||||
"""
|
||||
Quickly shows the mask annotations on the given matplotlib axis.
|
||||
|
||||
Args:
|
||||
annotation (array-like): Mask annotation.
|
||||
ax (matplotlib.axes.Axes): Matplotlib axis.
|
||||
random_color (bool, optional): Whether to use random color for masks. Defaults to False.
|
||||
bbox (list, optional): Bounding box coordinates [x1, y1, x2, y2]. Defaults to None.
|
||||
points (list, optional): Points to be plotted. Defaults to None.
|
||||
pointlabel (list, optional): Labels for the points. Defaults to None.
|
||||
retinamask (bool, optional): Whether to use retina mask. Defaults to True.
|
||||
target_height (int, optional): Target height for resizing. Defaults to 960.
|
||||
target_width (int, optional): Target width for resizing. Defaults to 960.
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
n, h, w = annotation.shape # batch, height, width
|
||||
|
||||
areas = np.sum(annotation, axis=(1, 2))
|
||||
annotation = annotation[np.argsort(areas)]
|
||||
|
||||
index = (annotation != 0).argmax(axis=0)
|
||||
if random_color:
|
||||
color = np.random.random((n, 1, 1, 3))
|
||||
else:
|
||||
color = np.ones((n, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 1.0])
|
||||
transparency = np.ones((n, 1, 1, 1)) * 0.6
|
||||
visual = np.concatenate([color, transparency], axis=-1)
|
||||
mask_image = np.expand_dims(annotation, -1) * visual
|
||||
|
||||
show = np.zeros((h, w, 4))
|
||||
h_indices, w_indices = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
|
||||
indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
|
||||
|
||||
show[h_indices, w_indices, :] = mask_image[indices]
|
||||
if bbox is not None:
|
||||
x1, y1, x2, y2 = bbox
|
||||
ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1))
|
||||
# Draw point
|
||||
if points is not None:
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
|
||||
s=20,
|
||||
c="y",
|
||||
)
|
||||
plt.scatter(
|
||||
[point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
[point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
|
||||
s=20,
|
||||
c="m",
|
||||
)
|
||||
|
||||
if not retinamask:
|
||||
show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
|
||||
ax.imshow(show)
|
||||
|
||||
@torch.no_grad()
|
||||
def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
|
||||
"""Processes images and text with a model, calculates similarity, and returns softmax score."""
|
||||
preprocessed_images = [preprocess(image).to(device) for image in elements]
|
||||
tokenized_text = self.clip.tokenize([search_text]).to(device)
|
||||
stacked_images = torch.stack(preprocessed_images)
|
||||
image_features = model.encode_image(stacked_images)
|
||||
text_features = model.encode_text(tokenized_text)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
probs = 100.0 * image_features @ text_features.T
|
||||
return probs[:, 0].softmax(dim=0)
|
||||
|
||||
def _crop_image(self, format_results):
|
||||
"""Crops an image based on provided annotation format and returns cropped images and related data."""
|
||||
image = Image.fromarray(cv2.cvtColor(self.results[0].orig_img, cv2.COLOR_BGR2RGB))
|
||||
ori_w, ori_h = image.size
|
||||
annotations = format_results
|
||||
mask_h, mask_w = annotations[0]["segmentation"].shape
|
||||
if ori_w != mask_w or ori_h != mask_h:
|
||||
image = image.resize((mask_w, mask_h))
|
||||
cropped_images = []
|
||||
filter_id = []
|
||||
for _, mask in enumerate(annotations):
|
||||
if np.sum(mask["segmentation"]) <= 100:
|
||||
filter_id.append(_)
|
||||
continue
|
||||
bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask
|
||||
cropped_images.append(self._segment_image(image, bbox)) # save cropped image
|
||||
|
||||
return cropped_images, filter_id, annotations
|
||||
|
||||
def box_prompt(self, bbox):
|
||||
"""Modifies the bounding box properties and calculates IoU between masks and bounding box."""
|
||||
if self.results[0].masks is not None:
|
||||
assert bbox[2] != 0 and bbox[3] != 0, "Bounding box width and height should not be zero"
|
||||
masks = self.results[0].masks.data
|
||||
target_height, target_width = self.results[0].orig_shape
|
||||
h = masks.shape[1]
|
||||
w = masks.shape[2]
|
||||
if h != target_height or w != target_width:
|
||||
bbox = [
|
||||
int(bbox[0] * w / target_width),
|
||||
int(bbox[1] * h / target_height),
|
||||
int(bbox[2] * w / target_width),
|
||||
int(bbox[3] * h / target_height),
|
||||
]
|
||||
bbox[0] = max(round(bbox[0]), 0)
|
||||
bbox[1] = max(round(bbox[1]), 0)
|
||||
bbox[2] = min(round(bbox[2]), w)
|
||||
bbox[3] = min(round(bbox[3]), h)
|
||||
|
||||
# IoUs = torch.zeros(len(masks), dtype=torch.float32)
|
||||
bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
|
||||
|
||||
masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
|
||||
orig_masks_area = torch.sum(masks, dim=(1, 2))
|
||||
|
||||
union = bbox_area + orig_masks_area - masks_area
|
||||
iou = masks_area / union
|
||||
max_iou_index = torch.argmax(iou)
|
||||
|
||||
self.results[0].masks.data = torch.tensor(np.array([masks[max_iou_index].cpu().numpy()]))
|
||||
return self.results
|
||||
|
||||
def point_prompt(self, points, pointlabel): # numpy
|
||||
"""Adjusts points on detected masks based on user input and returns the modified results."""
|
||||
if self.results[0].masks is not None:
|
||||
masks = self._format_results(self.results[0], 0)
|
||||
target_height, target_width = self.results[0].orig_shape
|
||||
h = masks[0]["segmentation"].shape[0]
|
||||
w = masks[0]["segmentation"].shape[1]
|
||||
if h != target_height or w != target_width:
|
||||
points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points]
|
||||
onemask = np.zeros((h, w))
|
||||
for annotation in masks:
|
||||
mask = annotation["segmentation"] if isinstance(annotation, dict) else annotation
|
||||
for i, point in enumerate(points):
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
|
||||
onemask += mask
|
||||
if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
|
||||
onemask -= mask
|
||||
onemask = onemask >= 1
|
||||
self.results[0].masks.data = torch.tensor(np.array([onemask]))
|
||||
return self.results
|
||||
|
||||
def text_prompt(self, text, clip_download_root=None):
|
||||
"""Processes a text prompt, applies it to existing results and returns the updated results."""
|
||||
if self.results[0].masks is not None:
|
||||
format_results = self._format_results(self.results[0], 0)
|
||||
cropped_images, filter_id, annotations = self._crop_image(format_results)
|
||||
clip_model, preprocess = self.clip.load("ViT-B/32", download_root=clip_download_root, device=self.device)
|
||||
scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
|
||||
max_idx = torch.argmax(scores)
|
||||
max_idx += sum(np.array(filter_id) <= int(max_idx))
|
||||
self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
|
||||
return self.results
|
||||
|
||||
def everything_prompt(self):
|
||||
"""Returns the processed results from the previous methods in the class."""
|
||||
return self.results
|
||||
|
|
@ -363,7 +363,7 @@ def scale_image(masks, im0_shape, ratio_pad=None):
|
|||
ratio_pad (tuple): the ratio of the padding to the original image.
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): The masks that are being returned.
|
||||
masks (np.ndarray): The masks that are being returned with shape [h, w, num].
|
||||
"""
|
||||
# Rescale coordinates (xyxy) from im1_shape to im0_shape
|
||||
im1_shape = masks.shape
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue