Explorer Cleanup (#7364)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
aca8eb1fd4
commit
ed73c0fedc
15 changed files with 585 additions and 63 deletions
|
|
@ -40,7 +40,7 @@ class ExplorerDataset(YOLODataset):
|
|||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp=None):
|
||||
transforms = Format(
|
||||
return Format(
|
||||
bbox_format='xyxy',
|
||||
normalize=False,
|
||||
return_mask=self.use_segments,
|
||||
|
|
@ -49,7 +49,6 @@ class ExplorerDataset(YOLODataset):
|
|||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask,
|
||||
)
|
||||
return transforms
|
||||
|
||||
|
||||
class Explorer:
|
||||
|
|
@ -161,8 +160,7 @@ class Explorer:
|
|||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
query = self.table.search(embeds).limit(limit).to_arrow()
|
||||
return query
|
||||
return self.table.search(embeds).limit(limit).to_arrow()
|
||||
|
||||
def sql_query(self, query, return_type='pandas'):
|
||||
"""
|
||||
|
|
@ -223,8 +221,7 @@ class Explorer:
|
|||
"""
|
||||
result = self.sql_query(query, return_type='arrow')
|
||||
img = plot_similar_images(result, plot_labels=labels)
|
||||
img = Image.fromarray(img)
|
||||
return img
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
|
||||
"""
|
||||
|
|
@ -276,8 +273,7 @@ class Explorer:
|
|||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type='arrow')
|
||||
img = plot_similar_images(similar, plot_labels=labels)
|
||||
img = Image.fromarray(img)
|
||||
return img
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
"""
|
||||
|
|
@ -331,7 +327,6 @@ class Explorer:
|
|||
|
||||
sim_table.add(_yield_sim_idx())
|
||||
self.sim_index = sim_table
|
||||
|
||||
return sim_table.to_pandas()
|
||||
|
||||
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
|
||||
|
|
@ -373,8 +368,7 @@ class Explorer:
|
|||
buffer.seek(0)
|
||||
|
||||
# Use Pillow to open the image from the buffer
|
||||
image = Image.open(buffer)
|
||||
return image
|
||||
return Image.open(buffer)
|
||||
|
||||
def _check_imgs_or_idxs(self, img, idx):
|
||||
if img is None and idx is None:
|
||||
|
|
@ -385,8 +379,7 @@ class Explorer:
|
|||
idx = idx if isinstance(idx, list) else [idx]
|
||||
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
|
||||
|
||||
img = img if isinstance(img, list) else [img]
|
||||
return img
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
def visualize(self, result):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue