From edca88d1b3c9d387c2c316ec5e7817d2ee839014 Mon Sep 17 00:00:00 2001
From: Nguyễn Anh Bình
Date: Tue, 16 Jul 2024 06:24:26 +0700
Subject: [PATCH] `ultralytics 8.2.58` FastSAM code refactor (#14450)

Co-authored-by: Glenn Jocher
---
 ultralytics/__init__.py              |  2 +-
 ultralytics/models/fastsam/prompt.py | 17 +++++++----------
 2 files changed, 8 insertions(+), 11 deletions(-)

diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 4f65b69e..9d9b62aa 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.2.57"
+__version__ = "8.2.58"
 
 import os
diff --git a/ultralytics/models/fastsam/prompt.py b/ultralytics/models/fastsam/prompt.py
index 4add9fbb..9738252e 100644
--- a/ultralytics/models/fastsam/prompt.py
+++ b/ultralytics/models/fastsam/prompt.py
@@ -7,6 +7,7 @@
 import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch import Tensor
 
 from ultralytics.utils import TQDM, checks
@@ -249,7 +250,7 @@ class FastSAMPrompt:
         ax.imshow(show)
 
     @torch.no_grad()
-    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+    def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
         """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
@@ -269,19 +270,16 @@
         mask_h, mask_w = annotations[0]["segmentation"].shape
         if ori_w != mask_w or ori_h != mask_h:
             image = image.resize((mask_w, mask_h))
-        cropped_boxes = []
         cropped_images = []
-        not_crop = []
         filter_id = []
         for _, mask in enumerate(annotations):
             if np.sum(mask["segmentation"]) <= 100:
                 filter_id.append(_)
                 continue
             bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
-            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
-            cropped_images.append(bbox)  # save cropped image bbox
+            cropped_images.append(self._segment_image(image, bbox))  # save cropped image
 
-        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+        return cropped_images, filter_id, annotations
 
     def box_prompt(self, bbox):
         """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
@@ -341,11 +339,10 @@
         """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
-            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            cropped_images, filter_id, annotations = self._crop_image(format_results)
             clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
-            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-            max_idx = scores.argsort()
-            max_idx = max_idx[-1]
+            scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
+            max_idx = torch.argmax(scores)
             max_idx += sum(np.array(filter_id) <= int(max_idx))
             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
             return self.results
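
For reference, the functional core of this patch is that retrieve() now returns the softmax score Tensor and text_prompt() picks the best-matching crop with torch.argmax instead of argsort()[-1], then offsets the index past masks that _crop_image() filtered out. Below is a minimal standalone sketch of that selection step, not part of the patch; the scores and filter_id values are made up for illustration.

import numpy as np
import torch

# Hypothetical softmax scores over the cropped images, as returned by the refactored retrieve()
scores = torch.tensor([0.05, 0.70, 0.25])

# The refactor replaces a full sort with a direct argmax; both select the same index
old_idx = scores.argsort()[-1]   # index of the highest score via argsort()[-1]
new_idx = torch.argmax(scores)   # index of the highest score directly
assert int(old_idx) == int(new_idx) == 1

# Shift the index past annotations that _crop_image() skipped (mask area <= 100 px),
# mirroring `max_idx += sum(np.array(filter_id) <= int(max_idx))` in text_prompt()
filter_id = [0]  # hypothetical: the first annotation was filtered out
max_idx = int(new_idx) + int(sum(np.array(filter_id) <= int(new_idx)))
print(max_idx)  # 2 -> index into the original, unfiltered annotation list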