ultralytics 8.2.58 FastSAM code refactor (#14450)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Nguyễn Anh Bình 2024-07-16 06:24:26 +07:00 committed by GitHub
parent c2f9a12cb4
commit edca88d1b3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 11 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.57"
__version__ = "8.2.58"
import os

View file

@ -7,6 +7,7 @@ import cv2
import numpy as np
import torch
from PIL import Image
from torch import Tensor
from ultralytics.utils import TQDM, checks
@ -249,7 +250,7 @@ class FastSAMPrompt:
ax.imshow(show)
@torch.no_grad()
def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
"""Processes images and text with a model, calculates similarity, and returns softmax score."""
preprocessed_images = [preprocess(image).to(device) for image in elements]
tokenized_text = self.clip.tokenize([search_text]).to(device)
@ -269,19 +270,16 @@ class FastSAMPrompt:
mask_h, mask_w = annotations[0]["segmentation"].shape
if ori_w != mask_w or ori_h != mask_h:
image = image.resize((mask_w, mask_h))
cropped_boxes = []
cropped_images = []
not_crop = []
filter_id = []
for _, mask in enumerate(annotations):
if np.sum(mask["segmentation"]) <= 100:
filter_id.append(_)
continue
bbox = self._get_bbox_from_mask(mask["segmentation"]) # bbox from mask
cropped_boxes.append(self._segment_image(image, bbox)) # save cropped image
cropped_images.append(bbox) # save cropped image bbox
cropped_images.append(self._segment_image(image, bbox)) # save cropped image
return cropped_boxes, cropped_images, not_crop, filter_id, annotations
return cropped_images, filter_id, annotations
def box_prompt(self, bbox):
"""Modifies the bounding box properties and calculates IoU between masks and bounding box."""
@ -341,11 +339,10 @@ class FastSAMPrompt:
"""Processes a text prompt, applies it to existing results and returns the updated results."""
if self.results[0].masks is not None:
format_results = self._format_results(self.results[0], 0)
cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
cropped_images, filter_id, annotations = self._crop_image(format_results)
clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
max_idx = scores.argsort()
max_idx = max_idx[-1]
scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
max_idx = torch.argmax(scores)
max_idx += sum(np.array(filter_id) <= int(max_idx))
self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
return self.results