Explorer Cleanup (#7364)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-07 01:31:41 +01:00 committed by GitHub
parent aca8eb1fd4
commit ed73c0fedc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 585 additions and 63 deletions

View file

@ -40,7 +40,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
transforms = Format(
return Format(
bbox_format='xyxy',
normalize=False,
return_mask=self.use_segments,
@ -49,7 +49,6 @@ class ExplorerDataset(YOLODataset):
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
return transforms
class Explorer:
@ -161,8 +160,7 @@ class Explorer:
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
query = self.table.search(embeds).limit(limit).to_arrow()
return query
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self, query, return_type='pandas'):
"""
@ -223,8 +221,7 @@ class Explorer:
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
"""
@ -276,8 +273,7 @@ class Explorer:
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
@ -331,7 +327,6 @@ class Explorer:
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
@ -373,8 +368,7 @@ class Explorer:
buffer.seek(0)
# Use Pillow to open the image from the buffer
image = Image.open(buffer)
return image
return Image.open(buffer)
def _check_imgs_or_idxs(self, img, idx):
if img is None and idx is None:
@ -385,8 +379,7 @@ class Explorer:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = img if isinstance(img, list) else [img]
return img
return img if isinstance(img, list) else [img]
def visualize(self, result):
"""

View file

@ -1,4 +1,3 @@
from pathlib import Path
from typing import List
import cv2
@ -94,10 +93,12 @@ def plot_similar_images(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
fname = 'temp_exp_grid.jpg'
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
max_subplots=len(images)).join()
img = cv2.imread(fname, cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
Path(fname).unlink()
return img_rgb
return plot_images(imgs,
batch_idx,
cls,
bboxes=boxes,
masks=masks,
kpts=kpts,
max_subplots=len(images),
save=False,
threaded=False)