ultralytics 8.2.58 FastSAM code refactor (#14450)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Author: Nguyễn Anh Bình
Date: 2024-07-16 06:24:26 +07:00 (committed by GitHub)
Parent: c2f9a12cb4
Commit: edca88d1b3
2 changed files with 8 additions and 11 deletions

ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = "8.2.57"
+__version__ = "8.2.58"
 import os

ultralytics/models/fastsam/prompt.py

@@ -7,6 +7,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch import Tensor

 from ultralytics.utils import TQDM, checks
@@ -249,7 +250,7 @@ class FastSAMPrompt:
         ax.imshow(show)

     @torch.no_grad()
-    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+    def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
         """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
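As its docstring says, retrieve encodes the cropped images and the query text with CLIP, computes their similarity, and returns a softmax score per crop, so the corrected annotation is Tensor rather than int. A minimal sketch of that pattern, assuming the standard OpenAI CLIP API (encode_image / encode_text); this is an illustration of the computation, not the upstream method body:

import torch

def clip_softmax_scores(clip_model, preprocessed_images, tokenized_text) -> torch.Tensor:
    """Illustrative sketch: reduce CLIP image/text similarity to one softmax score per crop."""
    stacked = torch.stack(preprocessed_images)               # (N, 3, H, W) batch of crops
    image_features = clip_model.encode_image(stacked)        # (N, D) image embeddings
    text_features = clip_model.encode_text(tokenized_text)   # (1, D) text embedding
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = 100.0 * image_features @ text_features.T         # (N, 1) scaled cosine similarity
    return probs[:, 0].softmax(dim=0)                        # (N,) Tensor of scores, hence the new return type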
@@ -269,19 +270,16 @@ class FastSAMPrompt:
         mask_h, mask_w = annotations[0]["segmentation"].shape
         if ori_w != mask_w or ori_h != mask_h:
             image = image.resize((mask_w, mask_h))
-        cropped_boxes = []
         cropped_images = []
-        not_crop = []
         filter_id = []
         for _, mask in enumerate(annotations):
             if np.sum(mask["segmentation"]) <= 100:
                 filter_id.append(_)
                 continue
             bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
-            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
-            cropped_images.append(bbox)  # save cropped image bbox
+            cropped_images.append(self._segment_image(image, bbox))  # save cropped image

-        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+        return cropped_images, filter_id, annotations

     def box_prompt(self, bbox):
         """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
@@ -341,11 +339,10 @@ class FastSAMPrompt:
         """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
-            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            cropped_images, filter_id, annotations = self._crop_image(format_results)
             clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
-            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-            max_idx = scores.argsort()
-            max_idx = max_idx[-1]
+            scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
+            max_idx = torch.argmax(scores)
             max_idx += sum(np.array(filter_id) <= int(max_idx))
             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
         return self.results
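For reference, a minimal usage sketch of the code path this refactor touches, following the documented FastSAM prompt workflow at this version; the weights file and image path are placeholders, and text_prompt additionally requires the optional CLIP dependency:

from ultralytics import FastSAM
from ultralytics.models.fastsam import FastSAMPrompt

# Placeholder weights and image path for illustration.
model = FastSAM("FastSAM-s.pt")
everything_results = model("image.jpg", device="cpu", retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

# text_prompt() runs _crop_image -> retrieve -> torch.argmax over the CLIP scores.
prompt_process = FastSAMPrompt("image.jpg", everything_results, device="cpu")
results = prompt_process.text_prompt(text="a photo of a dog")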