ultralytics 8.0.236 dataset semantic & SQL search API (#7136)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
parent
40a5c0abe7
commit
aca8eb1fd4
27 changed files with 1749 additions and 192 deletions
|
|
@ -1,7 +1,8 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.235'
|
||||
__version__ = '8.0.236'
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
from ultralytics.models.nas import NAS
|
||||
|
|
@ -9,4 +10,4 @@ from ultralytics.utils import SETTINGS as settings
|
|||
from ultralytics.utils.checks import check_yolo as checks
|
||||
from ultralytics.utils.downloads import download
|
||||
|
||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings'
|
||||
__all__ = '__version__', 'YOLO', 'NAS', 'SAM', 'FastSAM', 'RTDETR', 'checks', 'download', 'settings', 'Explorer'
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import contextlib
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
|
@ -56,6 +57,9 @@ CLI_HELP_MSG = \
|
|||
4. Export a YOLOv8n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolov8n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
6. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
|
||||
yolo explorer
|
||||
|
||||
5. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
|
|
@ -297,6 +301,12 @@ def handle_yolo_settings(args: List[str]) -> None:
|
|||
LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
|
||||
|
||||
|
||||
def handle_explorer():
    """Launch the Ultralytics Explorer dashboard by running its Streamlit script in a child process."""
    checks.check_requirements('streamlit')
    dashboard_script = ROOT / 'data/explorer/gui/dash.py'
    command = ['streamlit', 'run', dashboard_script, '--server.maxMessageSize', '2048']
    subprocess.run(command)
|
||||
|
||||
|
||||
def parse_key_value_pair(pair):
|
||||
"""Parse one 'key=value' pair and return key and value."""
|
||||
k, v = pair.split('=', 1) # split on first '=' sign
|
||||
|
|
@ -348,7 +358,8 @@ def entrypoint(debug=''):
|
|||
'cfg': lambda: yaml_print(DEFAULT_CFG_PATH),
|
||||
'hub': lambda: handle_yolo_hub(args[1:]),
|
||||
'login': lambda: handle_yolo_hub(args),
|
||||
'copy-cfg': copy_default_cfg}
|
||||
'copy-cfg': copy_default_cfg,
|
||||
'explorer': lambda: handle_explorer()}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
||||
# Define common misuses of special commands, i.e. -h, -help, --help
|
||||
|
|
|
|||
0
ultralytics/data/explorer/__init__.py
Normal file
0
ultralytics/data/explorer/__init__.py
Normal file
403
ultralytics/data/explorer/explorer.py
Normal file
403
ultralytics/data/explorer/explorer.py
Normal file
|
|
@ -0,0 +1,403 @@
|
|||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.data.augment import Format
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
from ultralytics.utils import LOGGER, checks
|
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
|
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
    """YOLODataset subclass that loads images without resizing and formats labels as un-normalized xyxy boxes."""

    def __init__(self, *args, data=None, **kwargs):
        """Forward every argument, including the dataset dict `data`, to YOLODataset."""
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i):
        """
        Load the image at dataset index `i` directly, without any resize operation.

        Args:
            i (int): Index into the dataset's image lists.

        Returns:
            (tuple): The image, its original (h, w) and its current (h, w). For a freshly loaded image both
                shapes are the stored image shape; for a RAM-cached image the cached values are returned.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        cached, img_path, npy_path = self.ims[i], self.im_files[i], self.npy_files[i]
        if cached is not None:  # already cached in RAM
            return self.ims[i], self.im_hw0[i], self.im_hw[i]

        if npy_path.exists():  # prefer the saved .npy copy
            im = np.load(npy_path)
        else:
            im = cv2.imread(img_path)  # BGR
            if im is None:
                raise FileNotFoundError(f'Image Not Found {img_path}')
        orig_h, orig_w = im.shape[:2]
        return im, (orig_h, orig_w), im.shape[:2]

    def build_transforms(self, hyp=None):
        """
        Build the label formatting transform used by the explorer.

        Args:
            hyp: Object providing `mask_ratio` and `overlap_mask`. Although the default is None, both attributes
                are read unconditionally, so a real value is required.

        Returns:
            (Format): Transform configured for un-normalized xyxy boxes with batch indices.
        """
        format_options = dict(
            bbox_format='xyxy',
            normalize=False,
            return_mask=self.use_segments,
            return_keypoint=self.use_keypoints,
            batch_idx=True,
            mask_ratio=hyp.mask_ratio,
            mask_overlap=hyp.overlap_mask,
        )
        return Format(**format_options)
|
||||
|
||||
|
||||
class Explorer:
    """
    Build and query a LanceDB table of image embeddings for a dataset.

    The table name is derived from the dataset file name and the model name. `progress` is a value between 0 and 1
    that is updated while embeddings are being created, so another thread can poll it.
    """

    def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
        """
        Connect to LanceDB and load the embedding model.

        Args:
            data (str): Dataset YAML name or path. Defaults to 'coco128.yaml'.
            model (str): YOLO model name or path used to compute embeddings. Defaults to 'yolov8n.pt'.
            uri (str): LanceDB connection URI. Defaults to '~/ultralytics/explorer'.
        """
        checks.check_requirements(['lancedb', 'duckdb'])
        import lancedb

        self.connection = lancedb.connect(uri)
        self.table_name = Path(data).name.lower() + '_' + model.lower()
        # Base name only: max_dist and top_k are appended later so each similarity index can be reused
        self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower()
        self.model = YOLO(model)
        self.data = data
        self.choice_set = None

        self.table = None
        self.progress = 0
        self.sim_index = None  # set by similarity_index()

    def create_embeddings_table(self, force=False, split='train'):
        """
        Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
        already exists. Pass force=True to overwrite the existing table.

        Args:
            force (bool): Whether to overwrite the existing table or not. Defaults to False.
            split (str): Split of the dataset to use. Defaults to 'train'.

        Raises:
            ValueError: If no dataset was provided or `split` is not a key of the dataset.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            ```
        """
        if self.table is not None and not force:
            LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
            self.progress = 1  # nothing left to do; lets pollers of `progress` finish
            return
        if self.table_name in self.connection.table_names() and not force:
            LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
            self.table = self.connection.open_table(self.table_name)
            self.progress = 1
            return
        if self.data is None:
            raise ValueError('Data must be provided to create embeddings table')

        data_info = check_det_dataset(self.data)
        if split not in data_info:
            raise ValueError(
                f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
            )

        choice_set = data_info[split]
        choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
        self.choice_set = choice_set
        dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)

        # Create the table schema; the first sample is embedded only to learn the vector size
        batch = dataset[0]
        vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
        Schema = get_table_schema(vector_size)
        table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
        table.add(
            self._yield_batches(dataset,
                                data_info,
                                self.model,
                                exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))

        self.table = table
        self.progress = 1

    def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
        """
        Yield one single-row batch per dataset sample, updating `self.progress` as it goes.

        Args:
            dataset: Indexable dataset whose items are dicts containing at least 'im_file'.
            data_info (dict): Dataset info passed to `sanitize_batch`.
            model: Model exposing `embed()`.
            exclude_keys (List): Keys removed from each sample before it is stored.
        """
        total = len(dataset)
        for i in tqdm(range(total)):
            self.progress = float(i + 1) / total
            batch = dataset[i]
            for k in exclude_keys:
                batch.pop(k, None)
            batch = sanitize_batch(batch, data_info)
            batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
            yield [batch]

    def query(self, imgs=None, limit=25):
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            imgs (str or list): Path to the image or a list of paths to the images.
            limit (int): Number of results to return.

        Returns:
            An arrow table containing the results. Supports converting to:
                - pandas dataframe: `result.to_pandas()`
                - dict of lists: `result.to_pydict()`

        Raises:
            ValueError: If the table has not been created or `imgs` is neither a string nor a list.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.query('https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')
        if isinstance(imgs, str):
            imgs = [imgs]
        elif not isinstance(imgs, list):
            raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
        embeds = self.model.embed(imgs)
        # Get avg if multiple images are passed (len > 1)
        embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
        return self.table.search(embeds).limit(limit).to_arrow()

    def sql_query(self, query, return_type='pandas'):
        """
        Run a SQL-like query on the embeddings table using duckdb.

        Args:
            query (str): Either a full query starting with SELECT, or just a WHERE clause, in which case
                `SELECT * FROM 'table'` is prepended. Surrounding whitespace and keyword case are ignored.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            A pandas dataframe when return_type is 'pandas', or an arrow table when it is 'arrow'.

        Raises:
            ValueError: If return_type is unknown, the table has not been created, or the query does not start
                with SELECT or WHERE.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.sql_query(query)
            ```
        """
        if return_type not in ('pandas', 'arrow'):
            raise ValueError(f"return_type must be 'pandas' or 'arrow'. Got {return_type!r}")
        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')

        query = query.strip()
        keyword = query.upper()
        if not keyword.startswith(('SELECT', 'WHERE')):
            raise ValueError(
                'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
        if keyword.startswith('WHERE'):
            query = f"SELECT * FROM 'table' {query}"
        LOGGER.info(f'Running query: {query}')

        import duckdb

        # Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
        # The SQL text refers to this local by the name 'table', so it must keep this name and live in this frame.
        table = self.table.to_arrow()  # noqa
        rs = duckdb.sql(query)
        return rs.df() if return_type == 'pandas' else rs.arrow()

    def plot_sql_query(self, query, labels=True):
        """
        Plot the results of a SQL-Like query on the table.

        Args:
            query (str): SQL query to run.
            labels (bool): Whether to plot the labels or not.

        Returns:
            PIL Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
            result = exp.plot_sql_query(query)
            ```
        """
        result = self.sql_query(query, return_type='arrow')
        return Image.fromarray(plot_similar_images(result, plot_labels=labels))

    def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
        """
        Query the table for similar images. Accepts a single image or a list of images.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            limit (int): Number of results to return. Defaults to 25.
            return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.

        Returns:
            A table or pandas dataframe containing the results.

        Raises:
            ValueError: If return_type is unknown, or if neither or both of `img` and `idx` are given.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        if return_type not in ('pandas', 'arrow'):
            raise ValueError(f"return_type must be 'pandas' or 'arrow'. Got {return_type!r}")
        img = self._check_imgs_or_idxs(img, idx)
        similar = self.query(img, limit=limit)
        return similar.to_pandas() if return_type == 'pandas' else similar

    def plot_similar(self, img=None, idx=None, limit=25, labels=True):
        """
        Plot the similar images. Accepts images or indexes.

        Args:
            img (str or list): Path to the image or a list of paths to the images.
            idx (int or list): Index of the image in the table or a list of indexes.
            labels (bool): Whether to plot the labels or not.
            limit (int): Number of results to return. Defaults to 25.

        Returns:
            PIL Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
            ```
        """
        similar = self.get_similar(img, idx, limit, return_type='arrow')
        return Image.fromarray(plot_similar_images(similar, plot_labels=labels))

    def similarity_index(self, max_dist=0.2, top_k=None, force=False):
        """
        Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
        are max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction (0.0-1.0) of the closest data points to consider when counting. Used to apply
                the limit when running vector search. Defaults to None, which searches the whole table.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            A pandas dataframe containing the similarity index.

        Raises:
            ValueError: If the table has not been created, or top_k / max_dist are out of range.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            sim_idx = exp.similarity_index()
            ```
        """
        if self.table is None:
            raise ValueError('Table is not created. Please create the table first.')
        sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
        if sim_idx_table_name in self.connection.table_names() and not force:
            LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
            self.sim_index = self.connection.open_table(sim_idx_table_name)
            return self.sim_index.to_pandas()

        if top_k and not (1.0 >= top_k >= 0.0):
            raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
        if max_dist < 0.0:
            raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')

        # Convert the fraction into an absolute row limit, never below 1
        top_k = int(top_k * len(self.table)) if top_k else len(self.table)
        top_k = max(top_k, 1)
        features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
        im_files = features['im_file']
        embeddings = features['vector']

        sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')

        def _yield_sim_idx():
            """Yield one similarity row per stored embedding."""
            for i, (embedding, im_file) in enumerate(tqdm(list(zip(embeddings, im_files)))):
                sim_idx = self.table.search(embedding).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
                yield [{
                    'idx': i,
                    'im_file': im_file,
                    'count': len(sim_idx),
                    'sim_im_files': sim_idx['im_file'].tolist()}]

        sim_table.add(_yield_sim_idx())
        self.sim_index = sim_table

        return sim_table.to_pandas()

    def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
        """
        Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
        max_dist or closer to the image in the embedding space at a given index.

        Args:
            max_dist (float): Maximum distance between the embeddings to consider. Defaults to 0.2.
            top_k (float): Fraction (0.0-1.0) of closest data points to consider when counting. Used to apply the
                limit when running vector search. Defaults to None, which searches the whole table.
            force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.

        Returns:
            PIL Image containing the plot.

        Example:
            ```python
            exp = Explorer()
            exp.create_embeddings_table()
            exp.plot_similarity_index()
            ```
        """
        sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
        sim_count = np.array(sim_idx['count'].tolist())
        indices = np.arange(len(sim_count))

        # Draw on a dedicated figure and close it afterwards, so repeated calls neither stack bars on the
        # implicit global figure nor keep figures alive
        fig, ax = plt.subplots()
        try:
            ax.bar(indices, sim_count)
            ax.set_xlabel('data idx')
            ax.set_ylabel('Count')
            ax.set_title('Similarity Count')
            buffer = BytesIO()
            fig.savefig(buffer, format='png')
        finally:
            plt.close(fig)
        buffer.seek(0)

        # Use Pillow to open the image from the buffer
        return Image.open(buffer)

    def _check_imgs_or_idxs(self, img, idx):
        """
        Resolve the `img` / `idx` pair into a list of image paths.

        Exactly one of the two must be given. Indexes are looked up in the table's 'im_file' column.

        Raises:
            ValueError: If neither or both of `img` and `idx` are provided.
        """
        if img is None and idx is None:
            raise ValueError('Either img or idx must be provided.')
        if img is not None and idx is not None:
            raise ValueError('Only one of img or idx must be provided.')
        if idx is not None:
            idx = idx if isinstance(idx, list) else [idx]
            img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']

        return img if isinstance(img, list) else [img]

    def visualize(self, result):
        """
        Visualize the results of a query. Not implemented yet; currently does nothing.

        Args:
            result (arrow table): Arrow table containing the results of a query.
        """
        # TODO:
        pass

    def generate_report(self, result):
        """Generate a report of the dataset. Not implemented yet; currently does nothing."""
        pass
|
||||
0
ultralytics/data/explorer/gui/__init__.py
Normal file
0
ultralytics/data/explorer/gui/__init__.py
Normal file
178
ultralytics/data/explorer/gui/dash.py
Normal file
178
ultralytics/data/explorer/gui/dash.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ROOT
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('streamlit')
|
||||
check_requirements('streamlit-select>=0.2')
|
||||
import streamlit as st
|
||||
from streamlit_select import image_select
|
||||
|
||||
|
||||
def _get_explorer():
    """
    Create an Explorer for the selected dataset/model and build its embeddings table with a progress bar.

    The table is built in a worker thread while this function polls `exp.progress`. The explorer is stored in
    `st.session_state['explorer']` only if the table was actually created.
    """
    exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
    thread = Thread(target=exp.create_embeddings_table,
                    kwargs={'force': st.session_state.get('force_recreate_embeddings')})
    thread.start()
    progress_bar = st.progress(0, text='Creating embeddings table...')
    # Also stop when the worker has ended: if it raised, progress never reaches 1 and polling on progress
    # alone would spin forever
    while thread.is_alive() and exp.progress < 1:
        time.sleep(0.1)
        progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
    thread.join()
    progress_bar.empty()
    if exp.table is None:  # worker failed before the table was assigned
        st.error('Failed to create the embeddings table. Check the console output for details.')
        return
    st.session_state['explorer'] = exp
|
||||
|
||||
|
||||
def init_explorer_form():
    """
    Render the form used to pick a dataset and a model before exploring.

    Dataset choices are the YAML files under ROOT/cfg/datasets. Selections are stored in session state under the
    keys 'dataset', 'model' and 'force_recreate_embeddings'; submitting runs `_get_explorer`.
    """
    datasets = ROOT / 'cfg' / 'datasets'
    ds = [d.name for d in datasets.glob('*.yaml')]
    models = [
        'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
        'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
        'yolov8l-pose.pt', 'yolov8x-pose.pt']
    # Preselect coco128 when it is available; fall back to the first entry instead of raising ValueError
    default_index = ds.index('coco128.yaml') if 'coco128.yaml' in ds else 0
    with st.form(key='explorer_init_form'):
        col1, col2 = st.columns(2)
        with col1:
            st.selectbox('Select dataset', ds, key='dataset', index=default_index)
        with col2:
            st.selectbox('Select model', models, key='model')
        st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')

        st.form_submit_button('Explore', on_click=_get_explorer)
|
||||
|
||||
|
||||
def query_form():
    """Render the SQL query form; the text is stored under the 'query' key and submitting runs `run_sql_query`."""
    with st.form('query_form'):
        text_col, button_col = st.columns([0.8, 0.2])
        with text_col:
            st.text_input('Query', '', label_visibility='collapsed', key='query')
        with button_col:
            st.form_submit_button('Query', on_click=run_sql_query)
|
||||
|
||||
|
||||
def find_similar_imgs(imgs):
    """
    Search for images similar to `imgs` and store the resulting file paths in `st.session_state['imgs']`.

    Args:
        imgs: Image path(s) forwarded to `Explorer.get_similar`.
    """
    explorer = st.session_state['explorer']
    max_results = st.session_state.get('limit')
    matches = explorer.get_similar(img=imgs, limit=max_results, return_type='arrow')
    st.session_state['imgs'] = matches.to_pydict()['im_file']
|
||||
|
||||
|
||||
def similarity_form(selected_imgs):
    """
    Render the similarity search form for the currently selected images.

    The result limit is stored under the 'limit' session key. The Search button is disabled, and an error is
    shown, when no image is selected; otherwise it runs `find_similar_imgs` with the selection.

    Args:
        selected_imgs: Sized collection of the selected images.
    """
    st.write('Similarity Search')
    with st.form('similarity_form'):
        limit_col, count_col = st.columns([1, 1])
        with limit_col:
            st.number_input('limit',
                            min_value=None,
                            max_value=None,
                            value=25,
                            label_visibility='collapsed',
                            key='limit')

        with count_col:
            nothing_selected = not len(selected_imgs)
            st.write('Selected: ', len(selected_imgs))
        st.form_submit_button('Search',
                              disabled=nothing_selected,
                              on_click=find_similar_imgs,
                              args=(selected_imgs, ))
        if nothing_selected:
            st.error('Select at least one image to search.')
|
||||
|
||||
|
||||
# def persist_reset_form():
|
||||
# with st.form("persist_reset"):
|
||||
# col1, col2 = st.columns([1, 1])
|
||||
# with col1:
|
||||
# st.form_submit_button("Reset", on_click=reset)
|
||||
#
|
||||
# with col2:
|
||||
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
|
||||
|
||||
|
||||
def run_sql_query():
    """
    Run the SQL query typed into the query form and store the matching file paths in `st.session_state['imgs']`.

    Does nothing when the query is missing or blank.
    """
    # A missing key yields None; treat it like an empty query instead of failing on a string method
    query = (st.session_state.get('query') or '').strip()
    if not query:
        return
    exp = st.session_state['explorer']
    # Pass the stripped text on: sql_query checks how the query starts, so leading whitespace would reject it
    res = exp.sql_query(query, return_type='arrow')
    st.session_state['imgs'] = res.to_pydict()['im_file']
|
||||
|
||||
|
||||
def reset_explorer():
    """Clear the stored explorer and image list so the dataset selection form is shown again."""
    for state_key in ('explorer', 'imgs'):
        st.session_state[state_key] = None
|
||||
|
||||
|
||||
def utralytics_explorer_docs_callback():
    """Render a bordered info box with the Ultralytics logo and links to the documentation."""
    docs_url = 'https://docs.ultralytics.com/'
    with st.container(border=True):
        st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
                 width=100)
        # The anchor previously had an empty href; it now points at the same docs URL as the button below
        st.markdown(
            f"<p>This demo is built using Ultralytics Explorer API. Visit <a href='{docs_url}'>API docs</a> to try examples & learn more</p>",
            unsafe_allow_html=True,
            help=None)
        st.link_button('Ultralytics Explorer API', docs_url)
|
||||
|
||||
|
||||
def layout():
    """
    Render the Explorer page.

    Shows the dataset/model selection form until an explorer exists in session state. Afterwards it shows the
    image grid with paging controls and the SQL query form on the left, and the similarity search form plus the
    docs box on the right.
    """
    st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
    st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)

    if st.session_state.get('explorer') is None:
        init_explorer_form()
        return

    st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
    exp = st.session_state.get('explorer')
    col1, col2 = st.columns([0.75, 0.25], gap='small')

    # Show the last query/search result if there is one, otherwise every image in the table
    imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
    total_imgs = len(imgs)
    selected_imgs = []  # stays empty when nothing is displayed, so the similarity form can still render
    with col1:
        subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
        with subcol1:
            st.write('Max Images Displayed:')
        with subcol2:
            num = st.number_input('Max Images Displayed',
                                  min_value=0,
                                  max_value=total_imgs,
                                  value=min(500, total_imgs),
                                  key='num_imgs_displayed',
                                  label_visibility='collapsed')
        with subcol3:
            st.write('Start Index:')
        with subcol4:
            start_idx = st.number_input('Start Index',
                                        min_value=0,
                                        max_value=total_imgs,
                                        value=0,
                                        key='start_index',
                                        label_visibility='collapsed')
        with subcol5:
            reset = st.button('Reset', use_container_width=False, key='reset')
            if reset:
                st.session_state['imgs'] = None
                st.rerun()  # replaces the deprecated st.experimental_rerun

        query_form()
        imgs_displayed = imgs[start_idx:start_idx + num]
        if imgs_displayed:
            selected_imgs = image_select(
                f'Total samples: {total_imgs}',
                images=imgs_displayed,
                use_container_width=False,
            )

    with col2:
        similarity_form(selected_imgs)
        utralytics_explorer_docs_callback()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
layout()
|
||||
103
ultralytics/data/explorer/utils.py
Normal file
103
ultralytics/data/explorer/utils.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
def get_table_schema(vector_size):
    """
    Build the LanceDB schema class for the embeddings table.

    Args:
        vector_size: Size passed to `Vector` for the embedding column.

    Returns:
        The `Schema` LanceModel subclass describing one table row.
    """
    # Imported lazily so lancedb is only needed when a schema is actually built
    from lancedb.pydantic import LanceModel, Vector

    class Schema(LanceModel):
        """One row of the embeddings table: image path, its labels and the embedding vector."""
        im_file: str
        labels: List[str]
        cls: List[int]
        bboxes: List[List[float]]
        masks: List[List[List[int]]]
        keypoints: List[List[List[float]]]
        vector: Vector(vector_size)

    return Schema
|
||||
|
||||
|
||||
def get_sim_index_schema():
    """
    Build the LanceDB schema class for a similarity index table.

    Returns:
        The `Schema` LanceModel subclass describing one similarity index row.
    """
    # Imported lazily so lancedb is only needed when a schema is actually built
    from lancedb.pydantic import LanceModel

    class Schema(LanceModel):
        """One row of the similarity index: an image, a count and a list of similar image files."""
        idx: int
        im_file: str
        count: int
        sim_im_files: List[str]

    return Schema
|
||||
|
||||
|
||||
def sanitize_batch(batch, dataset_info):
    """
    Convert a dataset sample into plain Python lists, in place, and return it.

    Boxes and class ids are sorted by class id, 'labels' is filled from `dataset_info['names']`, and 'masks' /
    'keypoints' are converted to lists or set to the placeholder `[[[]]]` when absent.

    Args:
        batch (dict): Sample with 'cls' and 'bboxes' entries that support `.tolist()`.
        dataset_info (dict): Dataset info whose 'names' entry maps class ids to names.

    Returns:
        (dict): The same `batch` object, updated.
    """
    batch['cls'] = batch['cls'].flatten().int().tolist()
    pairs = list(zip(batch['bboxes'].tolist(), batch['cls']))
    pairs.sort(key=lambda pair: pair[1])  # stable sort by class id
    batch['bboxes'] = [pair[0] for pair in pairs]
    batch['cls'] = [pair[1] for pair in pairs]
    class_names = dataset_info['names']
    batch['labels'] = [class_names[class_id] for class_id in batch['cls']]
    # NOTE(review): masks and keypoints keep their original order while boxes are sorted — confirm this is intended
    for optional_key in ('masks', 'keypoints'):
        if optional_key in batch:
            batch[optional_key] = batch[optional_key].tolist()
        else:
            batch[optional_key] = [[[]]]

    return batch
|
||||
|
||||
|
||||
def plot_similar_images(similar_set, plot_labels=True):
    """
    Plot images from the similar set as one grid.

    Args:
        similar_set: Pyarrow table containing the similar data points.
        plot_labels (bool): Whether to plot labels or not.

    Returns:
        (np.ndarray): The rendered grid as an RGB image.

    Raises:
        ValueError: If `similar_set` contains no images.
        FileNotFoundError: If one of the image files cannot be read.
    """
    similar_set = similar_set.to_pydict()
    empty_masks = [[[]]]  # placeholder stored for rows without masks / keypoints
    images = similar_set.get('im_file', [])
    if not images:
        raise ValueError('No images to plot: the result set is empty.')
    # The previous `is not [[]]` identity test was always true, so the boxes are simply taken as-is
    bboxes = similar_set.get('bboxes', [])
    # Tolerate a missing or empty column instead of indexing [0] blindly
    masks = similar_set.get('masks') or []
    if masks and masks[0] == empty_masks:
        masks = []
    kpts = similar_set.get('keypoints') or []
    if kpts and kpts[0] == empty_masks:
        kpts = []
    cls = similar_set.get('cls', [])

    plot_size = 640
    imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
    for i, imf in enumerate(images):
        im = cv2.imread(imf)
        if im is None:
            raise FileNotFoundError(f'Image Not Found {imf}')
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
        h, w = im.shape[:2]
        r = min(plot_size / h, plot_size / w)  # scale applied to labels to match the letterboxed image
        imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))
        if plot_labels:
            if len(bboxes) > i and len(bboxes[i]) > 0:
                box = np.array(bboxes[i], dtype=np.float32)
                box[:, [0, 2]] *= r
                box[:, [1, 3]] *= r
                plot_boxes.append(box)
            if len(masks) > i and len(masks[i]) > 0:
                mask = np.array(masks[i], dtype=np.uint8)[0]
                plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
            if len(kpts) > i and kpts[i] is not None:
                kpt = np.array(kpts[i], dtype=np.float32)
                kpt[:, :, :2] *= r
                plot_kpts.append(kpt)
        # One batch index entry per box of image i
        batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
    imgs = np.stack(imgs, axis=0)
    masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
    kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32)
    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32)
    batch_idx = np.concatenate(batch_idx, axis=0)
    cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)

    fname = 'temp_exp_grid.jpg'
    try:
        # The grid is written to `fname`; .join() waits for the call's returned handle before reading it back
        plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
                    max_subplots=len(images)).join()
        img = cv2.imread(fname, cv2.IMREAD_COLOR)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    finally:
        # Remove the temporary file even when reading it back fails
        Path(fname).unlink(missing_ok=True)
    return img_rgb
|
||||
|
|
@ -579,7 +579,8 @@ def plot_images(images,
|
|||
paths=None,
|
||||
fname='images.jpg',
|
||||
names=None,
|
||||
on_plot=None):
|
||||
on_plot=None,
|
||||
max_subplots=16):
|
||||
"""Plot image grid with labels."""
|
||||
if isinstance(images, torch.Tensor):
|
||||
images = images.cpu().float().numpy()
|
||||
|
|
@ -595,7 +596,7 @@ def plot_images(images,
|
|||
batch_idx = batch_idx.cpu().numpy()
|
||||
|
||||
max_size = 1920 # max image size
|
||||
max_subplots = 16 # max image subplots, i.e. 4x4
|
||||
max_subplots = max_subplots # max image subplots, i.e. 4x4
|
||||
bs, _, h, w = images.shape # batch size, _, height, width
|
||||
bs = min(bs, max_subplots) # limit plot images
|
||||
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
||||
|
|
@ -685,7 +686,7 @@ def plot_images(images,
|
|||
image_masks = np.where(image_masks == index, 1.0, 0.0)
|
||||
|
||||
im = np.asarray(annotator.im).copy()
|
||||
for j, box in enumerate(boxes.T.tolist()):
|
||||
for j in range(len(image_masks)):
|
||||
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
||||
color = colors(classes[j])
|
||||
mh, mw = image_masks[j].shape
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue