ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: uwer <uwe.rosebrock@gmail.com> Co-authored-by: Uwe Rosebrock <ro260@csiro.au> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: AdamP <adamp87hun@gmail.com>
This commit is contained in:
parent
e76754eab0
commit
783033fa6b
19 changed files with 387 additions and 76 deletions
|
|
@ -0,0 +1,3 @@
|
|||
from .utils import plot_query_result
|
||||
|
||||
__all__ = ['plot_query_result']
|
||||
|
|
@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
|
|||
from ultralytics.models.yolo.model import YOLO
|
||||
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
|
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
|
||||
|
|
@ -58,7 +58,7 @@ class Explorer:
|
|||
data: Union[str, Path] = 'coco128.yaml',
|
||||
model: str = 'yolov8n.pt',
|
||||
uri: str = '~/ultralytics/explorer') -> None:
|
||||
checks.check_requirements(['lancedb', 'duckdb'])
|
||||
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
|
||||
import lancedb
|
||||
|
||||
self.connection = lancedb.connect(uri)
|
||||
|
|
@ -112,8 +112,7 @@ class Explorer:
|
|||
# Create the table schema
|
||||
batch = dataset[0]
|
||||
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
|
||||
Schema = get_table_schema(vector_size)
|
||||
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
|
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
|
||||
table.add(
|
||||
self._yield_batches(dataset,
|
||||
data_info,
|
||||
|
|
@ -159,10 +158,7 @@ class Explorer:
|
|||
raise ValueError('Table is not created. Please create the table first.')
|
||||
if isinstance(imgs, str):
|
||||
imgs = [imgs]
|
||||
elif isinstance(imgs, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
|
||||
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
|
||||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
|
|
@ -189,16 +185,19 @@ class Explorer:
|
|||
result = exp.sql_query(query)
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
import duckdb
|
||||
|
||||
if self.table is None:
|
||||
raise ValueError('Table is not created. Please create the table first.')
|
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
||||
table = self.table.to_arrow() # noqa
|
||||
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
||||
if not query.startswith('SELECT') and not query.startswith('WHERE'):
|
||||
raise ValueError(
|
||||
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
|
||||
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
|
||||
)
|
||||
if query.startswith('WHERE'):
|
||||
query = f"SELECT * FROM 'table' {query}"
|
||||
LOGGER.info(f'Running query: {query}')
|
||||
|
|
@ -228,7 +227,10 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
result = self.sql_query(query, return_type='arrow')
|
||||
img = plot_similar_images(result, plot_labels=labels)
|
||||
if len(result) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
return None
|
||||
img = plot_query_result(result, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self,
|
||||
|
|
@ -255,6 +257,8 @@ class Explorer:
|
|||
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
img = self._check_imgs_or_idxs(img, idx)
|
||||
similar = self.query(img, limit=limit)
|
||||
|
||||
|
|
@ -288,7 +292,10 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type='arrow')
|
||||
img = plot_similar_images(similar, plot_labels=labels)
|
||||
if len(similar) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
return None
|
||||
img = plot_query_result(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
|
||||
|
|
@ -299,7 +306,7 @@ class Explorer:
|
|||
Args:
|
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
||||
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
|
||||
vector search. Defaults to 0.01.
|
||||
vector search. Defaults: None.
|
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
|
||||
|
||||
Returns:
|
||||
|
|
@ -401,6 +408,32 @@ class Explorer:
|
|||
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
def ask_ai(self, query):
    """
    Turn a natural-language question into a SQL query and run it on the table.

    The question is handed to `prompt_sql_query`, and the text it returns is executed through
    `self.sql_query`. Any exception raised while running that query is logged, not propagated.

    Args:
        query (str): Question to ask.

    Returns:
        The value returned by `self.sql_query` for the generated query, or None if running it raised.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        answer = exp.ask_ai('Show images with 1 person and 2 dogs')
        ```
    """
    generated_sql = prompt_sql_query(query)
    try:
        return self.sql_query(generated_sql)
    except Exception as err:
        # Best-effort: the generated SQL may be malformed, so report the failure and fall through to None.
        LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
        LOGGER.error(err)
    return None
|
||||
|
||||
def visualize(self, result):
|
||||
"""
|
||||
Visualize the results of a query.
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ROOT
|
||||
from ultralytics.utils import ROOT, SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('streamlit')
|
||||
check_requirements('streamlit>=1.29.0')
|
||||
check_requirements('streamlit-select>=0.2')
|
||||
import streamlit as st
|
||||
from streamlit_select import image_select
|
||||
|
|
@ -35,9 +37,9 @@ def init_explorer_form():
|
|||
with st.form(key='explorer_init_form'):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
|
||||
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
|
||||
with col2:
|
||||
model = st.selectbox('Select model', models, key='model')
|
||||
st.selectbox('Select model', models, key='model')
|
||||
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
|
||||
|
||||
st.form_submit_button('Explore', on_click=_get_explorer)
|
||||
|
|
@ -47,11 +49,23 @@ def query_form():
|
|||
with st.form('query_form'):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
|
||||
st.text_input('Query',
|
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
||||
label_visibility='collapsed',
|
||||
key='query')
|
||||
with col2:
|
||||
st.form_submit_button('Query', on_click=run_sql_query)
|
||||
|
||||
|
||||
def ai_query_form():
    """Render the 'ai_query_form' form: a text box bound to the 'ai_query' key and an 'Ask AI' submit button."""
    with st.form('ai_query_form'):
        # Wide column for the text box, narrow column for the submit button.
        text_col, button_col = st.columns([0.8, 0.2])
        with text_col:
            st.text_input(
                'Query',
                'Show images with 1 person and 1 dog',
                label_visibility='collapsed',
                key='ai_query',
            )
        with button_col:
            # run_ai_query is invoked as the submit callback.
            st.form_submit_button('Ask AI', on_click=run_ai_query)
|
||||
|
||||
|
||||
def find_similar_imgs(imgs):
|
||||
exp = st.session_state['explorer']
|
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
|
||||
|
|
@ -64,12 +78,12 @@ def similarity_form(selected_imgs):
|
|||
with st.form('similarity_form'):
|
||||
subcol1, subcol2 = st.columns([1, 1])
|
||||
with subcol1:
|
||||
limit = st.number_input('limit',
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
value=25,
|
||||
label_visibility='collapsed',
|
||||
key='limit')
|
||||
st.number_input('limit',
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
value=25,
|
||||
label_visibility='collapsed',
|
||||
key='limit')
|
||||
|
||||
with subcol2:
|
||||
disabled = not len(selected_imgs)
|
||||
|
|
@ -95,6 +109,7 @@ def similarity_form(selected_imgs):
|
|||
|
||||
|
||||
def run_sql_query():
|
||||
st.session_state['error'] = None
|
||||
query = st.session_state.get('query')
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state['explorer']
|
||||
|
|
@ -102,9 +117,26 @@ def run_sql_query():
|
|||
st.session_state['imgs'] = res.to_pydict()['im_file']
|
||||
|
||||
|
||||
def run_ai_query():
    """
    Submit callback for the "Ask AI" form.

    Reads the text stored under the 'ai_query' session key, passes it to the explorer's `ask_ai`, and
    stores the resulting image paths under 'imgs'. Problems are reported through the 'error' session key
    instead of being raised:
      - no OpenAI API key in SETTINGS,
      - `ask_ai` returned something other than a non-empty DataFrame.

    A missing or whitespace-only query clears any previous error and does nothing else.
    """
    if not SETTINGS['openai_api_key']:
        st.session_state['error'] = (
            'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."')
        return

    st.session_state['error'] = None
    query = st.session_state.get('ai_query')
    # .get() yields None when the key is absent; guard before calling str methods on it.
    if not query or not query.strip():
        return

    exp = st.session_state['explorer']
    res = exp.ask_ai(query)
    # ask_ai returns None when the generated query failed to run.
    if not isinstance(res, pd.DataFrame) or res.empty:
        st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
        return
    st.session_state['imgs'] = res['im_file'].to_list()
|
||||
|
||||
|
||||
def reset_explorer():
|
||||
st.session_state['explorer'] = None
|
||||
st.session_state['imgs'] = None
|
||||
st.session_state['error'] = None
|
||||
|
||||
|
||||
def utralytics_explorer_docs_callback():
|
||||
|
|
@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
|
|||
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
|
||||
width=100)
|
||||
st.markdown(
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
|
||||
unsafe_allow_html=True,
|
||||
help=None)
|
||||
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
|
||||
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
|
||||
|
||||
|
||||
def layout():
|
||||
|
|
@ -129,9 +161,12 @@ def layout():
|
|||
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
|
||||
exp = st.session_state.get('explorer')
|
||||
col1, col2 = st.columns([0.75, 0.25], gap='small')
|
||||
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
|
||||
total_imgs = len(imgs)
|
||||
imgs = []
|
||||
if st.session_state.get('error'):
|
||||
st.error(st.session_state['error'])
|
||||
else:
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
|
||||
total_imgs, selected_imgs = len(imgs), []
|
||||
with col1:
|
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
|
||||
with subcol1:
|
||||
|
|
@ -159,6 +194,7 @@ def layout():
|
|||
st.experimental_rerun()
|
||||
|
||||
query_form()
|
||||
ai_query_form()
|
||||
if total_imgs:
|
||||
imgs_displayed = imgs[start_idx:start_idx + num]
|
||||
selected_imgs = image_select(
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
import getpass
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils import LOGGER as logger
|
||||
from ultralytics.utils import SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
|
@ -47,15 +52,16 @@ def sanitize_batch(batch, dataset_info):
|
|||
return batch
|
||||
|
||||
|
||||
def plot_similar_images(similar_set, plot_labels=True):
|
||||
def plot_query_result(similar_set, plot_labels=True):
|
||||
"""
|
||||
Plot images from the similar set.
|
||||
|
||||
Args:
|
||||
similar_set (list): Pyarrow table containing the similar data points
|
||||
similar_set (list): Pyarrow or pandas object containing the similar data points
|
||||
plot_labels (bool): Whether to plot labels or not
|
||||
"""
|
||||
similar_set = similar_set.to_pydict()
|
||||
similar_set = similar_set.to_dict(
|
||||
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
|
||||
empty_masks = [[[]]]
|
||||
empty_boxes = [[]]
|
||||
images = similar_set.get('im_file', [])
|
||||
|
|
@ -102,3 +108,61 @@ def plot_similar_images(similar_set, plot_labels=True):
|
|||
max_subplots=len(images),
|
||||
save=False,
|
||||
threaded=False)
|
||||
|
||||
|
||||
def prompt_sql_query(query):
    """
    Ask an OpenAI chat model to produce one SQL query for a plain-language request.

    If SETTINGS holds no 'openai_api_key', a warning is logged, the key is read interactively with
    getpass and written back into SETTINGS. The request is then sent to the 'gpt-3.5-turbo' chat
    model together with a system prompt describing the table schema and an example query.

    Args:
        query: User request; it is formatted into the user message with an f-string.

    Returns:
        The message content of the first choice in the chat completion response.
    """
    check_requirements('openai>=1.6.1')
    from openai import OpenAI

    if not SETTINGS['openai_api_key']:
        logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
        SETTINGS.update({'openai_api_key': getpass.getpass('OpenAI API key: ')})
    client = OpenAI(api_key=SETTINGS['openai_api_key'])

    # System prompt: task description, table schema, column notes and one worked example.
    system_prompt = '''
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`

Schema:
im_file: string not null
labels: list<item: string> not null
child 0, item: string
cls: list<item: int64> not null
child 0, item: int64
bboxes: list<item: list<item: double>> not null
child 0, item: list<item: double>
child 0, item: double
masks: list<item: list<item: list<item: int64>>> not null
child 0, item: list<item: list<item: int64>>
child 0, item: list<item: int64>
child 0, item: int64
keypoints: list<item: list<item: list<item: double>>> not null
child 0, item: list<item: list<item: double>>
child 0, item: list<item: double>
child 0, item: double
vector: fixed_size_list<item: float>[256] not null
child 0, item: float

Some details about the schema:
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
in each image
- the "cls" column contains the integer values on these classes that map them the labels

Example of a correct query:
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''

    system_message = {'role': 'system', 'content': system_prompt}
    user_message = {'role': 'user', 'content': f'{query}'}

    response = client.chat.completions.create(model='gpt-3.5-turbo', messages=[system_message, user_message])
    return response.choices[0].message.content
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue