ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer <uwe.rosebrock@gmail.com>
Co-authored-by: Uwe Rosebrock <ro260@csiro.au>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: AdamP <adamp87hun@gmail.com>
This commit is contained in:
Glenn Jocher 2024-01-08 23:36:29 +01:00 committed by GitHub
parent e76754eab0
commit 783033fa6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 387 additions and 76 deletions

View file

@ -0,0 +1,3 @@
from .utils import plot_query_result
__all__ = ['plot_query_result']

View file

@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
class ExplorerDataset(YOLODataset):
@ -58,7 +58,7 @@ class Explorer:
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
@ -112,8 +112,7 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
@ -159,10 +158,7 @@ class Explorer:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
@ -189,16 +185,19 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
)
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
@ -228,7 +227,10 @@ class Explorer:
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
if len(result) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(self,
@ -255,6 +257,8 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
@ -288,7 +292,10 @@ class Explorer:
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
if len(similar) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
@ -299,7 +306,7 @@ class Explorer:
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults to 0.01.
vector search. Defaults: None.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
@ -401,6 +408,32 @@ class Explorer:
return img if isinstance(img, list) else [img]
def ask_ai(self, query):
    """
    Turn a natural-language request into a SQL query and run it on the table.

    The request is passed to `prompt_sql_query`; the text it returns is then executed with `self.sql_query`.

    Args:
        query (str): Question to ask.

    Returns:
        The result of `self.sql_query` for the generated query, or None if executing it raised an exception.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        answer = exp.ask_ai('Show images with 1 person and 2 dogs')
        ```
    """
    generated_sql = prompt_sql_query(query)
    try:
        return self.sql_query(generated_sql)
    except Exception as err:
        # The generated text is not guaranteed to be runnable SQL, so any failure is logged rather than raised.
        LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
        LOGGER.error(err)
    return None
def visualize(self, result):
"""
Visualize the results of a query.

View file

@ -1,11 +1,13 @@
import time
from threading import Thread
import pandas as pd
from ultralytics import Explorer
from ultralytics.utils import ROOT
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements('streamlit')
check_requirements('streamlit>=1.29.0')
check_requirements('streamlit-select>=0.2')
import streamlit as st
from streamlit_select import image_select
@ -35,9 +37,9 @@ def init_explorer_form():
with st.form(key='explorer_init_form'):
col1, col2 = st.columns(2)
with col1:
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
with col2:
model = st.selectbox('Select model', models, key='model')
st.selectbox('Select model', models, key='model')
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
st.form_submit_button('Explore', on_click=_get_explorer)
@ -47,11 +49,23 @@ def query_form():
with st.form('query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
st.text_input('Query',
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility='collapsed',
key='query')
with col2:
st.form_submit_button('Query', on_click=run_sql_query)
def ai_query_form():
    """Render the 'ai_query_form' form: a prompt text box next to an 'Ask AI' submit button."""
    with st.form('ai_query_form'):
        # Wide column for the prompt, narrow column for the button.
        prompt_col, submit_col = st.columns([0.8, 0.2])
        with prompt_col:
            # The entered text is stored in session state under 'ai_query'.
            st.text_input(
                'Query',
                'Show images with 1 person and 1 dog',
                label_visibility='collapsed',
                key='ai_query',
            )
        with submit_col:
            st.form_submit_button('Ask AI', on_click=run_ai_query)
def find_similar_imgs(imgs):
exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
@ -64,12 +78,12 @@ def similarity_form(selected_imgs):
with st.form('similarity_form'):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
limit = st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
with subcol2:
disabled = not len(selected_imgs)
@ -95,6 +109,7 @@ def similarity_form(selected_imgs):
def run_sql_query():
st.session_state['error'] = None
query = st.session_state.get('query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
@ -102,9 +117,26 @@ def run_sql_query():
st.session_state['imgs'] = res.to_pydict()['im_file']
def run_ai_query():
    """
    Run the prompt stored under 'ai_query' through the explorer's `ask_ai` and publish the outcome.

    Reads 'ai_query' and 'explorer' from `st.session_state`. On success, the `im_file` column of the result is
    stored as a list under 'imgs'. When the OpenAI API key setting is empty, or the result is not a non-empty
    pandas DataFrame, an explanatory message is stored under 'error' instead. A missing or blank prompt is a
    no-op apart from clearing 'error'.
    """
    if not SETTINGS['openai_api_key']:
        st.session_state[
            'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
        return
    st.session_state['error'] = None

    query = st.session_state.get('ai_query')
    # `.get` returns None when the key is absent; treat that the same as a blank prompt instead of crashing.
    if not query or not query.strip():
        return

    res = st.session_state['explorer'].ask_ai(query)
    if not isinstance(res, pd.DataFrame) or res.empty:
        st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
        return
    st.session_state['imgs'] = res['im_file'].to_list()
def reset_explorer():
    """Set the 'explorer', 'imgs' and 'error' entries of the Streamlit session state to None."""
    for state_key in ('explorer', 'imgs', 'error'):
        st.session_state[state_key] = None
def utralytics_explorer_docs_callback():
@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None)
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
def layout():
@ -129,9 +161,12 @@ def layout():
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
exp = st.session_state.get('explorer')
col1, col2 = st.columns([0.75, 0.25], gap='small')
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs = len(imgs)
imgs = []
if st.session_state.get('error'):
st.error(st.session_state['error'])
else:
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
@ -159,6 +194,7 @@ def layout():
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
imgs_displayed = imgs[start_idx:start_idx + num]
selected_imgs = image_select(

View file

@ -1,9 +1,14 @@
import getpass
from typing import List
import cv2
import numpy as np
import pandas as pd
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images
@ -47,15 +52,16 @@ def sanitize_batch(batch, dataset_info):
return batch
def plot_similar_images(similar_set, plot_labels=True):
def plot_query_result(similar_set, plot_labels=True):
"""
Plot images from the similar set.
Args:
similar_set (list): Pyarrow table containing the similar data points
similar_set (list): Pyarrow or pandas object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = similar_set.to_pydict()
similar_set = similar_set.to_dict(
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get('im_file', [])
@ -102,3 +108,61 @@ def plot_similar_images(similar_set, plot_labels=True):
max_subplots=len(images),
save=False,
threaded=False)
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to produce a SQL query for a natural-language request.

    Ensures the `openai>=1.6.1` requirement via `check_requirements`. If the 'openai_api_key' setting is empty,
    the key is requested interactively with `getpass` and written back into SETTINGS. The request is sent to the
    'gpt-3.5-turbo' model together with a system prompt describing the table schema.

    Args:
        query (str): The user's request in natural language.

    Returns:
        The message content of the first choice in the chat completion response.
    """
    check_requirements('openai>=1.6.1')
    from openai import OpenAI

    if not SETTINGS['openai_api_key']:
        logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
        entered_key = getpass.getpass('OpenAI API key: ')
        SETTINGS.update({'openai_api_key': entered_key})
    client = OpenAI(api_key=SETTINGS['openai_api_key'])

    # System prompt: fixes the output format and describes the table schema to the model.
    system_prompt = '''
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
Schema:
im_file: string not null
labels: list<item: string> not null
child 0, item: string
cls: list<item: int64> not null
child 0, item: int64
bboxes: list<item: list<item: double>> not null
child 0, item: list<item: double>
child 0, item: double
masks: list<item: list<item: list<item: int64>>> not null
child 0, item: list<item: list<item: int64>>
child 0, item: list<item: int64>
child 0, item: int64
keypoints: list<item: list<item: list<item: double>>> not null
child 0, item: list<item: list<item: double>>
child 0, item: list<item: double>
child 0, item: double
vector: fixed_size_list<item: float>[256] not null
child 0, item: float
Some details about the schema:
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
in each image
- the "cls" column contains the integer values on these classes that map them the labels
Example of a correct query:
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''
    system_message = {'role': 'system', 'content': system_prompt}
    user_message = {'role': 'user', 'content': f'{query}'}

    response = client.chat.completions.create(model='gpt-3.5-turbo', messages=[system_message, user_message])
    return response.choices[0].message.content