ultralytics 8.0.236 dataset semantic & SQL search API (#7136)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
This commit is contained in:
Ayush Chaurasia 2024-01-07 00:23:25 +05:30 committed by GitHub
parent 40a5c0abe7
commit aca8eb1fd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 1749 additions and 192 deletions

View file

View file

@ -0,0 +1,403 @@
from io import BytesIO
from pathlib import Path
from typing import List
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm
from ultralytics.data.augment import Format
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
class ExplorerDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, data=data, **kwargs)
# NOTE: Load the image directly without any resize operations.
def load_image(self, i):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
if im is None: # not cached in RAM
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
h0, w0 = im.shape[:2] # orig hw
return im, (h0, w0), im.shape[:2]
return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None):
transforms = Format(
bbox_format='xyxy',
normalize=False,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
return transforms
class Explorer:
def __init__(self, data='coco128.yaml', model='yolov8n.pt', uri='~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
self.table_name = Path(data).name.lower() + '_' + model.lower()
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
) # Use this name and append thres and top_k to reuse the table
self.model = YOLO(model)
self.data = data # None
self.choice_set = None
self.table = None
self.progress = 0
def create_embeddings_table(self, force=False, split='train'):
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
Args:
force (bool): Whether to overwrite the existing table or not. Defaults to False.
split (str): Split of the dataset to use. Defaults to 'train'.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
```
"""
if self.table is not None and not force:
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
return
if self.table_name in self.connection.table_names() and not force:
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
self.table = self.connection.open_table(self.table_name)
self.progress = 1
return
if self.data is None:
raise ValueError('Data must be provided to create embeddings table')
data_info = check_det_dataset(self.data)
if split not in data_info:
raise ValueError(
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
)
choice_set = data_info[split]
choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
self.choice_set = choice_set
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
self.model,
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
self.table = table
def _yield_batches(self, dataset, data_info, model, exclude_keys: List):
# Implement Batching
for i in tqdm(range(len(dataset))):
self.progress = float(i + 1) / len(dataset)
batch = dataset[i]
for k in exclude_keys:
batch.pop(k, None)
batch = sanitize_batch(batch, data_info)
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
yield [batch]
def query(self, imgs=None, limit=25):
"""
Query the table for similar images. Accepts a single image or a list of images.
Args:
imgs (str or list): Path to the image or a list of paths to the images.
limit (int): Number of results to return.
Returns:
An arrow table containing the results. Supports converting to:
- pandas dataframe: `result.to_pandas()`
- dict of lists: `result.to_pydict()`
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.query(img='https://ultralytics.com/images/zidane.jpg')
```
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
query = self.table.search(embeds).limit(limit).to_arrow()
return query
def sql_query(self, query, return_type='pandas'):
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
Args:
query (str): SQL query to run.
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns:
An arrow table containing the results.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
result = exp.sql_query(query)
```
"""
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
rs = duckdb.sql(query)
if return_type == 'pandas':
return rs.df()
elif return_type == 'arrow':
return rs.arrow()
def plot_sql_query(self, query, labels=True):
"""
Plot the results of a SQL-Like query on the table.
Args:
query (str): SQL query to run.
labels (bool): Whether to plot the labels or not.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
query = 'SELECT * FROM table WHERE labels LIKE "%person%"'
result = exp.plot_sql_query(query)
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
img = Image.fromarray(img)
return img
def get_similar(self, img=None, idx=None, limit=25, return_type='pandas'):
"""
Query the table for similar images. Accepts a single image or a list of images.
Args:
img (str or list): Path to the image or a list of paths to the images.
idx (int or list): Index of the image in the table or a list of indexes.
limit (int): Number of results to return. Defaults to 25.
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
Returns:
A table or pandas dataframe containing the results.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
if return_type == 'pandas':
return similar.to_pandas()
elif return_type == 'arrow':
return similar
def plot_similar(self, img=None, idx=None, limit=25, labels=True):
"""
Plot the similar images. Accepts images or indexes.
Args:
img (str or list): Path to the image or a list of paths to the images.
idx (int or list): Index of the image in the table or a list of indexes.
labels (bool): Whether to plot the labels or not.
limit (int): Number of results to return. Defaults to 25.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
img = Image.fromarray(img)
return img
def similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
are max_dist or closer to the image in the embedding space at a given index.
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults to 0.01.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
A pandas dataframe containing the similarity index.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
sim_idx = exp.similarity_index()
```
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
if sim_idx_table_name in self.connection.table_names() and not force:
LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
return self.connection.open_table(sim_idx_table_name).to_pandas()
if top_k and not (1.0 >= top_k >= 0.0):
raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
if max_dist < 0.0:
raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
top_k = max(top_k, 1)
features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
im_files = features['im_file']
embeddings = features['vector']
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
def _yield_sim_idx():
for i in tqdm(range(len(embeddings))):
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
yield [{
'idx': i,
'im_file': im_files[i],
'count': len(sim_idx),
'sim_im_files': sim_idx['im_file'].tolist()}]
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
return sim_table.to_pandas()
def plot_similarity_index(self, max_dist=0.2, top_k=None, force=False):
"""
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
max_dist or closer to the image in the embedding space at a given index.
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
running vector search. Defaults to 0.01.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
PIL Image containing the plot.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
exp.plot_similarity_index()
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
sim_count = sim_idx['count'].tolist()
sim_count = np.array(sim_count)
indices = np.arange(len(sim_count))
# Create the bar plot
plt.bar(indices, sim_count)
# Customize the plot (optional)
plt.xlabel('data idx')
plt.ylabel('Count')
plt.title('Similarity Count')
buffer = BytesIO()
plt.savefig(buffer, format='png')
buffer.seek(0)
# Use Pillow to open the image from the buffer
image = Image.open(buffer)
return image
def _check_imgs_or_idxs(self, img, idx):
if img is None and idx is None:
raise ValueError('Either img or idx must be provided.')
if img is not None and idx is not None:
raise ValueError('Only one of img or idx must be provided.')
if idx is not None:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = img if isinstance(img, list) else [img]
return img
def visualize(self, result):
"""
Visualize the results of a query.
Args:
result (arrow table): Arrow table containing the results of a query.
"""
# TODO:
pass
def generate_report(self, result):
"""Generate a report of the dataset."""
pass

View file

@ -0,0 +1,178 @@
import time
from threading import Thread
from ultralytics import Explorer
from ultralytics.utils import ROOT
from ultralytics.utils.checks import check_requirements
check_requirements('streamlit')
check_requirements('streamlit-select>=0.2')
import streamlit as st
from streamlit_select import image_select
def _get_explorer():
exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
thread = Thread(target=exp.create_embeddings_table,
kwargs={'force': st.session_state.get('force_recreate_embeddings')})
thread.start()
progress_bar = st.progress(0, text='Creating embeddings table...')
while exp.progress < 1:
time.sleep(0.1)
progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
thread.join()
st.session_state['explorer'] = exp
progress_bar.empty()
def init_explorer_form():
datasets = ROOT / 'cfg' / 'datasets'
ds = [d.name for d in datasets.glob('*.yaml')]
models = [
'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
'yolov8l-pose.pt', 'yolov8x-pose.pt']
with st.form(key='explorer_init_form'):
col1, col2 = st.columns(2)
with col1:
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
with col2:
model = st.selectbox('Select model', models, key='model')
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
st.form_submit_button('Explore', on_click=_get_explorer)
def query_form():
with st.form('query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
with col2:
st.form_submit_button('Query', on_click=run_sql_query)
def find_similar_imgs(imgs):
exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
paths = similar.to_pydict()['im_file']
st.session_state['imgs'] = paths
def similarity_form(selected_imgs):
st.write('Similarity Search')
with st.form('similarity_form'):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
limit = st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
with subcol2:
disabled = not len(selected_imgs)
st.write('Selected: ', len(selected_imgs))
st.form_submit_button(
'Search',
disabled=disabled,
on_click=find_similar_imgs,
args=(selected_imgs, ),
)
if disabled:
st.error('Select at least one image to search.')
# def persist_reset_form():
# with st.form("persist_reset"):
# col1, col2 = st.columns([1, 1])
# with col1:
# st.form_submit_button("Reset", on_click=reset)
#
# with col2:
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
def run_sql_query():
query = st.session_state.get('query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
res = exp.sql_query(query, return_type='arrow')
st.session_state['imgs'] = res.to_pydict()['im_file']
def reset_explorer():
st.session_state['explorer'] = None
st.session_state['imgs'] = None
def utralytics_explorer_docs_callback():
with st.container(border=True):
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None)
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
def layout():
st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
if st.session_state.get('explorer') is None:
init_explorer_form()
return
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
exp = st.session_state.get('explorer')
col1, col2 = st.columns([0.75, 0.25], gap='small')
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs = len(imgs)
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
st.write('Max Images Displayed:')
with subcol2:
num = st.number_input('Max Images Displayed',
min_value=0,
max_value=total_imgs,
value=min(500, total_imgs),
key='num_imgs_displayed',
label_visibility='collapsed')
with subcol3:
st.write('Start Index:')
with subcol4:
start_idx = st.number_input('Start Index',
min_value=0,
max_value=total_imgs,
value=0,
key='start_index',
label_visibility='collapsed')
with subcol5:
reset = st.button('Reset', use_container_width=False, key='reset')
if reset:
st.session_state['imgs'] = None
st.experimental_rerun()
query_form()
if total_imgs:
imgs_displayed = imgs[start_idx:start_idx + num]
selected_imgs = image_select(
f'Total samples: {total_imgs}',
images=imgs_displayed,
use_container_width=False,
# indices=[i for i in range(num)] if select_all else None,
)
with col2:
similarity_form(selected_imgs)
# display_labels = st.checkbox("Labels", value=False, key="display_labels")
utralytics_explorer_docs_callback()
if __name__ == '__main__':
layout()

View file

@ -0,0 +1,103 @@
from pathlib import Path
from typing import List
import cv2
import numpy as np
from ultralytics.data.augment import LetterBox
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images
def get_table_schema(vector_size):
from lancedb.pydantic import LanceModel, Vector
class Schema(LanceModel):
im_file: str
labels: List[str]
cls: List[int]
bboxes: List[List[float]]
masks: List[List[List[int]]]
keypoints: List[List[List[float]]]
vector: Vector(vector_size)
return Schema
def get_sim_index_schema():
from lancedb.pydantic import LanceModel
class Schema(LanceModel):
idx: int
im_file: str
count: int
sim_im_files: List[str]
return Schema
def sanitize_batch(batch, dataset_info):
batch['cls'] = batch['cls'].flatten().int().tolist()
box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
batch['bboxes'] = [box for box, _ in box_cls_pair]
batch['cls'] = [cls for _, cls in box_cls_pair]
batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
return batch
def plot_similar_images(similar_set, plot_labels=True):
"""
Plot images from the similar set.
Args:
similar_set (list): Pyarrow table containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = similar_set.to_pydict()
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get('im_file', [])
bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
cls = similar_set.get('cls', [])
plot_size = 640
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
for i, imf in enumerate(images):
im = cv2.imread(imf)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
h, w = im.shape[:2]
r = min(plot_size / h, plot_size / w)
imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))
if plot_labels:
if len(bboxes) > i and len(bboxes[i]) > 0:
box = np.array(bboxes[i], dtype=np.float32)
box[:, [0, 2]] *= r
box[:, [1, 3]] *= r
plot_boxes.append(box)
if len(masks) > i and len(masks[i]) > 0:
mask = np.array(masks[i], dtype=np.uint8)[0]
plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
if len(kpts) > i and kpts[i] is not None:
kpt = np.array(kpts[i], dtype=np.float32)
kpt[:, :, :2] *= r
plot_kpts.append(kpt)
batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
imgs = np.stack(imgs, axis=0)
masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32)
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32)
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
fname = 'temp_exp_grid.jpg'
plot_images(imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, fname=fname,
max_subplots=len(images)).join()
img = cv2.imread(fname, cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
Path(fname).unlink()
return img_rgb