ultralytics 8.2.58 FastSAM code refactor (#14450)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent c2f9a12cb4
commit edca88d1b3
2 changed files with 8 additions and 11 deletions
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.57"
+__version__ = "8.2.58"

 import os
@@ -7,6 +7,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch import Tensor

 from ultralytics.utils import TQDM, checks
@@ -249,7 +250,7 @@ class FastSAMPrompt:
         ax.imshow(show)

     @torch.no_grad()
-    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+    def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
         """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
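The return annotation fix above reflects what retrieve actually produces: the softmax of CLIP image-text similarities, one score per crop, as a torch.Tensor rather than an int. A minimal sketch of that scoring step, assuming a CLIP-style model that exposes encode_image/encode_text (function name and shapes here are illustrative, not taken from this hunk):

import torch

def clip_similarity(model, preprocessed_images, tokenized_text) -> torch.Tensor:
    """Illustrative only: softmax of CLIP image-text similarities, one score per crop."""
    stacked_images = torch.stack(preprocessed_images)      # (N, 3, H, W)
    image_features = model.encode_image(stacked_images)    # (N, D)
    text_features = model.encode_text(tokenized_text)      # (1, D)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = 100.0 * image_features @ text_features.T       # (N, 1) scaled cosine similarity
    return probs[:, 0].softmax(dim=0)                      # (N,) Tensor of softmax scores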
@@ -269,19 +270,16 @@ class FastSAMPrompt:
         mask_h, mask_w = annotations[0]["segmentation"].shape
         if ori_w != mask_w or ori_h != mask_h:
             image = image.resize((mask_w, mask_h))
-        cropped_boxes = []
         cropped_images = []
-        not_crop = []
         filter_id = []
         for _, mask in enumerate(annotations):
             if np.sum(mask["segmentation"]) <= 100:
                 filter_id.append(_)
                 continue
             bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
-            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
-            cropped_images.append(bbox)  # save cropped image bbox
+            cropped_images.append(self._segment_image(image, bbox))  # save cropped image

-        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+        return cropped_images, filter_id, annotations

     def box_prompt(self, bbox):
         """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
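Only the masked crops (the old cropped_boxes list) are ever consumed downstream in text_prompt, so _crop_image now returns just those crops, the indices of the masks it skipped (area <= 100 px), and the annotations. A self-contained stand-in for the new contract; crop_contract and its variable names are hypothetical, not from the repo:

import numpy as np

def crop_contract(masks):
    """Hypothetical mimic: one bbox crop per kept mask, plus indices of masks filtered out."""
    crops, filter_id = [], []
    for i, seg in enumerate(masks):
        if seg.sum() <= 100:  # same area threshold as the diff
            filter_id.append(i)
            continue
        ys, xs = np.where(seg)
        crops.append(seg[ys.min() : ys.max() + 1, xs.min() : xs.max() + 1])  # crop to mask bbox
    return crops, filter_id

masks = [np.zeros((64, 64), np.uint8), np.ones((64, 64), np.uint8)]
crops, filter_id = crop_contract(masks)
assert len(crops) + len(filter_id) == len(masks)  # mirrors the new return shape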
@@ -341,11 +339,10 @@ class FastSAMPrompt:
         """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
-            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            cropped_images, filter_id, annotations = self._crop_image(format_results)
             clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
-            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-            max_idx = scores.argsort()
-            max_idx = max_idx[-1]
+            scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
+            max_idx = torch.argmax(scores)
             max_idx += sum(np.array(filter_id) <= int(max_idx))
             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
         return self.results
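torch.argmax(scores) replaces the scores.argsort() / max_idx[-1] pair and picks the same thing: the highest-scoring crop among those that survived the area filter. Because small masks were dropped before scoring, the filter_id line then shifts that position back to an index into the full annotation list. A small worked example with dummy values (not from the repo):

import numpy as np
import torch

scores = torch.tensor([0.1, 0.7, 0.2])  # CLIP scores for the kept crops only
filter_id = [1, 3]                       # annotation indices dropped before scoring

max_idx = torch.argmax(scores)                        # tensor(1): best of the kept crops
max_idx += sum(np.array(filter_id) <= int(max_idx))   # shift past earlier dropped masks
print(int(max_idx))                                   # 2: index into the full annotation list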