Explorer Cleanup (#7364)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-07 01:31:41 +01:00 committed by GitHub
parent aca8eb1fd4
commit ed73c0fedc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 585 additions and 63 deletions

View file

@ -40,7 +40,7 @@ class ExplorerDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
transforms = Format(
return Format(
bbox_format='xyxy',
normalize=False,
return_mask=self.use_segments,
@ -49,7 +49,6 @@ class ExplorerDataset(YOLODataset):
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
return transforms
class Explorer:
@ -161,8 +160,7 @@ class Explorer:
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
query = self.table.search(embeds).limit(limit).to_arrow()
return query
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self, query, return_type='pandas'):
"""
@ -223,8 +221,7 @@ class Explorer:
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
"""
@ -276,8 +273,7 @@ class Explorer:
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
img = Image.fromarray(img)
return img
return Image.fromarray(img)
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
@ -331,7 +327,6 @@ class Explorer:
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
@ -373,8 +368,7 @@ class Explorer:
buffer.seek(0)
# Use Pillow to open the image from the buffer
image = Image.open(buffer)
return image
return Image.open(buffer)
def _check_imgs_or_idxs(self, img, idx):
if img is None and idx is None:
@ -385,8 +379,7 @@ class Explorer:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = img if isinstance(img, list) else [img]
return img
return img if isinstance(img, list) else [img]
def visualize(self, result):
"""

View file

@ -1,4 +1,3 @@
from pathlib import Path
from typing import List
import cv2
@ -94,10 +93,12 @@ def plot_similar_images(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
fname = 'temp_exp_grid.jpg'
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
max_subplots=len(images)).join()
img = cv2.imread(fname, cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
Path(fname).unlink()
return img_rgb
return plot_images(imgs,
batch_idx,
cls,
bboxes=boxes,
masks=masks,
kpts=kpts,
max_subplots=len(images),
save=False,
threaded=False)

View file

@ -736,16 +736,19 @@ class TryExcept(contextlib.ContextDecorator):
def threaded(func):
    """
    Multi-thread a target function by default and return the thread, or run it inline.

    Use as @threaded decorator. The wrapped function runs in a daemon thread unless
    'threaded=False' is passed at call time, in which case it runs synchronously and
    its return value is passed straight through.
    """
    from functools import wraps  # local import so this block stays self-contained

    @wraps(func)  # preserve func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        """Run `func` in a daemon thread (returning the Thread) or inline, based on the 'threaded' kwarg."""
        # Pop the control kwarg so it is never forwarded to `func` itself.
        if kwargs.pop('threaded', True):  # run in thread
            thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
            thread.start()
            return thread
        return func(*args, **kwargs)  # threaded=False: synchronous call, real return value

    return wrapper

View file

@ -125,7 +125,7 @@ class Annotator:
if rotated:
p1 = [int(b) for b in box[0]]
# NOTE: cv2-version polylines needs np.asarray type.
cv2.polylines(self.im, [np.asarray(box, dtype=np.int)], True, color, self.lw)
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
else:
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
@ -580,7 +580,8 @@ def plot_images(images,
fname='images.jpg',
names=None,
on_plot=None,
max_subplots=16):
max_subplots=16,
save=True):
"""Plot image grid with labels."""
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
@ -596,7 +597,6 @@ def plot_images(images,
batch_idx = batch_idx.cpu().numpy()
max_size = 1920 # max image size
max_subplots = max_subplots # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
@ -605,12 +605,9 @@ def plot_images(images,
# Build Image
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
for i, im in enumerate(images):
if i == max_subplots: # if last batch has fewer images than we expect
break
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
im = im.transpose(1, 2, 0)
mosaic[y:y + h, x:x + w, :] = im
mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
# Resize (optional)
scale = max_size / ns / max(h, w)
@ -622,7 +619,7 @@ def plot_images(images,
# Annotate
fs = int((h + w) * ns * 0.01) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
for i in range(i + 1):
for i in range(bs):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
if paths:
@ -699,9 +696,12 @@ def plot_images(images,
with contextlib.suppress(Exception):
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
annotator.fromarray(im)
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
if save:
annotator.im.save(fname) # save
if on_plot:
on_plot(fname)
else:
return np.asarray(annotator.im)
@plt_settings()