ultralytics 8.2.58 FastSAM code refactor (#14450)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent c2f9a12cb4
commit edca88d1b3
2 changed files with 8 additions and 11 deletions
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.2.57"
+__version__ = "8.2.58"
 
 import os
@@ -7,6 +7,7 @@ import cv2
 import numpy as np
 import torch
 from PIL import Image
+from torch import Tensor
 
 from ultralytics.utils import TQDM, checks
@@ -249,7 +250,7 @@ class FastSAMPrompt:
         ax.imshow(show)
 
     @torch.no_grad()
-    def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
+    def retrieve(self, model, preprocess, elements, search_text: str, device) -> Tensor:
         """Processes images and text with a model, calculates similarity, and returns softmax score."""
         preprocessed_images = [preprocess(image).to(device) for image in elements]
         tokenized_text = self.clip.tokenize([search_text]).to(device)
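The return annotation moves from `int` to `Tensor` because `retrieve` yields the per-crop softmax scores rather than a single index. A minimal sketch of the scoring step it performs, assuming the standard OpenAI `clip` package that FastSAM imports lazily (the function name and setup here are illustrative, not the repository's code):

```python
import clip  # assumption: the OpenAI CLIP package FastSAM checks for and imports
import torch


@torch.no_grad()
def retrieve_sketch(model, preprocess, elements, search_text: str, device) -> torch.Tensor:
    """Score each cropped image against a text query; returns a 1-D Tensor of softmax scores."""
    images = torch.stack([preprocess(im).to(device) for im in elements])
    tokens = clip.tokenize([search_text]).to(device)
    image_features = model.encode_image(images)                  # (N, D)
    text_features = model.encode_text(tokens)                    # (1, D)
    image_features /= image_features.norm(dim=-1, keepdim=True)  # unit-normalize
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = 100.0 * image_features @ text_features.T        # cosine similarity per crop
    return similarity[:, 0].softmax(dim=0)                       # Tensor of scores, not an int
```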
@@ -269,19 +270,16 @@ class FastSAMPrompt:
         mask_h, mask_w = annotations[0]["segmentation"].shape
         if ori_w != mask_w or ori_h != mask_h:
             image = image.resize((mask_w, mask_h))
-        cropped_boxes = []
         cropped_images = []
-        not_crop = []
         filter_id = []
         for _, mask in enumerate(annotations):
             if np.sum(mask["segmentation"]) <= 100:
                 filter_id.append(_)
                 continue
             bbox = self._get_bbox_from_mask(mask["segmentation"])  # bbox from mask
-            cropped_boxes.append(self._segment_image(image, bbox))  # save cropped image
-            cropped_images.append(bbox)  # save cropped image bbox
+            cropped_images.append(self._segment_image(image, bbox))  # save cropped image
 
-        return cropped_boxes, cropped_images, not_crop, filter_id, annotations
+        return cropped_images, filter_id, annotations
 
     def box_prompt(self, bbox):
         """Modifies the bounding box properties and calculates IoU between masks and bounding box."""
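The refactor drops the unused `cropped_boxes` and `not_crop` accumulators, so `_crop_image` now returns only the crops, the indices of masks it skipped, and the annotations. A rough standalone sketch of that loop, using `PIL.Image.crop` as a stand-in for the class's `_segment_image` helper:

```python
import numpy as np
from PIL import Image


def crop_masks_sketch(image: Image.Image, annotations: list):
    """Keep one crop per mask, skip masks covering <= 100 pixels, record skipped indices."""
    cropped_images, filter_id = [], []
    for i, ann in enumerate(annotations):
        mask = ann["segmentation"]
        if np.sum(mask) <= 100:          # too small to give CLIP a useful crop
            filter_id.append(i)
            continue
        ys, xs = np.nonzero(mask)        # bounding box of the mask (x1, y1, x2, y2)
        bbox = (int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1)
        cropped_images.append(image.crop(bbox))
    return cropped_images, filter_id, annotations
```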
@@ -341,11 +339,10 @@ class FastSAMPrompt:
         """Processes a text prompt, applies it to existing results and returns the updated results."""
         if self.results[0].masks is not None:
             format_results = self._format_results(self.results[0], 0)
-            cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results)
+            cropped_images, filter_id, annotations = self._crop_image(format_results)
             clip_model, preprocess = self.clip.load("ViT-B/32", device=self.device)
-            scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device)
-            max_idx = scores.argsort()
-            max_idx = max_idx[-1]
+            scores = self.retrieve(clip_model, preprocess, cropped_images, text, device=self.device)
+            max_idx = torch.argmax(scores)
             max_idx += sum(np.array(filter_id) <= int(max_idx))
             self.results[0].masks.data = torch.tensor(np.array([annotations[max_idx]["segmentation"]]))
         return self.results
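For a 1-D score tensor, `torch.argmax(scores)` selects the same element as the old `scores.argsort()` followed by `max_idx[-1]`, without sorting every score. A quick check of the equivalence with illustrative values:

```python
import torch

scores = torch.tensor([0.10, 0.62, 0.05, 0.23])  # illustrative softmax scores
old_idx = scores.argsort()[-1]                   # previous: full ascending sort, take last index
new_idx = torch.argmax(scores)                   # refactored: single pass over the scores
assert int(old_idx) == int(new_idx) == 1
```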