Explorer Cleanup (#7364)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
aca8eb1fd4
commit
ed73c0fedc
15 changed files with 585 additions and 63 deletions
|
|
@ -40,7 +40,7 @@ class ExplorerDataset(YOLODataset):
|
|||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
transforms = Format(
|
||||
return Format(
|
||||
bbox_format='xyxy',
|
||||
normalize=False,
|
||||
return_mask=self.use_segments,
|
||||
|
|
@ -49,7 +49,6 @@ class ExplorerDataset(YOLODataset):
|
|||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask,
|
||||
)
|
||||
return transforms
|
||||
|
||||
|
||||
class Explorer:
|
||||
|
|
@ -161,8 +160,7 @@ class Explorer:
|
|||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
query = self.table.search(embeds).limit(limit).to_arrow()
|
||||
return query
|
||||
return self.table.search(embeds).limit(limit).to_arrow()
|
||||
|
||||
def sql_query(self, query, return_type='pandas'):
|
||||
"""
|
||||
|
|
@ -223,8 +221,7 @@ class Explorer:
|
|||
"""
|
||||
result = self.sql_query(query, return_type='arrow')
|
||||
img = plot_similar_images(result, plot_labels=labels)
|
||||
img = Image.fromarray(img)
|
||||
return img
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
|
||||
"""
|
||||
|
|
@ -276,8 +273,7 @@ class Explorer:
|
|||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type='arrow')
|
||||
img = plot_similar_images(similar, plot_labels=labels)
|
||||
img = Image.fromarray(img)
|
||||
return img
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
"""
|
||||
|
|
@ -331,7 +327,6 @@ class Explorer:
|
|||
|
||||
sim_table.add(_yield_sim_idx())
|
||||
self.sim_index = sim_table
|
||||
|
||||
return sim_table.to_pandas()
|
||||
|
||||
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
|
|
@ -373,8 +368,7 @@ class Explorer:
|
|||
buffer.seek(0)
|
||||
|
||||
# Use Pillow to open the image from the buffer
|
||||
image = Image.open(buffer)
|
||||
return image
|
||||
return Image.open(buffer)
|
||||
|
||||
def _check_imgs_or_idxs(self, img, idx):
|
||||
if img is None and idx is None:
|
||||
|
|
@ -385,8 +379,7 @@ class Explorer:
|
|||
idx = idx if isinstance(idx, list) else [idx]
|
||||
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
|
||||
|
||||
img = img if isinstance(img, list) else [img]
|
||||
return img
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
def visualize(self, result):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
|
|
@ -94,10 +93,12 @@ def plot_similar_images(similar_set, plot_labels=True):
|
|||
batch_idx = np.concatenate(batch_idx, axis=0)
|
||||
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
|
||||
|
||||
fname = 'temp_exp_grid.jpg'
|
||||
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
|
||||
max_subplots=len(images)).join()
|
||||
img = cv2.imread(fname, cv2.IMREAD_COLOR)
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
Path(fname).unlink()
|
||||
return img_rgb
|
||||
return plot_images(imgs,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes=boxes,
|
||||
masks=masks,
|
||||
kpts=kpts,
|
||||
max_subplots=len(images),
|
||||
save=False,
|
||||
threaded=False)
|
||||
|
|
|
|||
|
|
@ -736,16 +736,19 @@ class TryExcept(contextlib.ContextDecorator):
|
|||
|
||||
def threaded(func):
|
||||
"""
|
||||
Multi-threads a target function and returns thread.
|
||||
Multi-threads a target function by default and returns the thread or function result.
|
||||
|
||||
Use as @threaded decorator.
|
||||
Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
"""Multi-threads a given function and returns the thread."""
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
"""Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result."""
|
||||
if kwargs.pop('threaded', True): # run in thread
|
||||
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class Annotator:
|
|||
if rotated:
|
||||
p1 = [int(b) for b in box[0]]
|
||||
# NOTE: cv2-version polylines needs np.asarray type.
|
||||
cv2.polylines(self.im, [np.asarray(box, dtype=np.int)], True, color, self.lw)
|
||||
cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw)
|
||||
else:
|
||||
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
||||
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
||||
|
|
@ -580,7 +580,8 @@ def plot_images(images,
|
|||
fname='images.jpg',
|
||||
names=None,
|
||||
on_plot=None,
|
||||
max_subplots=16):
|
||||
max_subplots=16,
|
||||
save=True):
|
||||
"""Plot image grid with labels."""
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
|
|
@ -596,7 +597,6 @@ def plot_images(images,
|
|||
batch_idx = batch_idx.cpu().numpy()
|
||||
|
||||
max_size = 1920 # max image size
|
||||
max_subplots = max_subplots # max image subplots, i.e. 4x4
|
||||
bs, _, h, w = images.shape # batch size, _, height, width
|
||||
bs = min(bs, max_subplots) # limit plot images
|
||||
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
||||
|
|
@ -605,12 +605,9 @@ def plot_images(images,
|
|||
|
||||
# Build Image
|
||||
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
||||
for i, im in enumerate(images):
|
||||
if i == max_subplots: # if last batch has fewer images than we expect
|
||||
break
|
||||
for i in range(bs):
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
im = im.transpose(1, 2, 0)
|
||||
mosaic[y:y + h, x:x + w, :] = im
|
||||
mosaic[y:y + h, x:x + w, :] = images[i].transpose(1, 2, 0)
|
||||
|
||||
# Resize (optional)
|
||||
scale = max_size / ns / max(h, w)
|
||||
|
|
@ -622,7 +619,7 @@ def plot_images(images,
|
|||
# Annotate
|
||||
fs = int((h + w) * ns * 0.01) # font size
|
||||
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
||||
for i in range(i + 1):
|
||||
for i in range(bs):
|
||||
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
||||
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
||||
if paths:
|
||||
|
|
@ -699,9 +696,12 @@ def plot_images(images,
|
|||
with contextlib.suppress(Exception):
|
||||
im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
|
||||
annotator.fromarray(im)
|
||||
annotator.im.save(fname) # save
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
if save:
|
||||
annotator.im.save(fname) # save
|
||||
if on_plot:
|
||||
on_plot(fname)
|
||||
else:
|
||||
return np.asarray(annotator.im)
|
||||
|
||||
|
||||
@plt_settings()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue