ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer <uwe.rosebrock@gmail.com>
Co-authored-by: Uwe Rosebrock <ro260@csiro.au>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: AdamP <adamp87hun@gmail.com>
Glenn Jocher 2024-01-08 23:36:29 +01:00 committed by GitHub
parent e76754eab0
commit 783033fa6b
19 changed files with 387 additions and 76 deletions

View file

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.237'
__version__ = '8.0.238'
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO
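The version bump above ships alongside a new top-level `Explorer` export. A minimal quickstart sketch (not part of the commit itself; the dataset and model strings mirror the constructor defaults shown further down):

```python
from ultralytics import Explorer

# Build the LanceDB embeddings table that all queries below depend on
exp = Explorer(data='coco128.yaml', model='yolov8n.pt')
exp.create_embeddings_table()
```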

View file

@@ -0,0 +1,3 @@
from .utils import plot_query_result
__all__ = ['plot_query_result']

View file

@@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
class ExplorerDataset(YOLODataset):
@@ -58,7 +58,7 @@ class Explorer:
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
@@ -112,8 +112,7 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
@@ -159,10 +158,7 @@ class Explorer:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
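A hedged sketch of the behavior above: `query()` accepts one image path or a list, and several images are mean-pooled into a single embedding before the vector search (file names illustrative; `exp` from the quickstart sketch above):

```python
# Returns an Arrow table of the nearest rows in the embeddings table
similar = exp.query(imgs=['bus.jpg', 'zidane.jpg'], limit=10)
```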
@@ -189,16 +185,19 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
)
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
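The relaxed check above lets callers pass either a full statement or a bare `WHERE` clause, which gets expanded to `SELECT * FROM 'table' ...`. A small sketch (`exp` from the quickstart above):

```python
# Both forms are equivalent after expansion
df = exp.sql_query("WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
                   return_type='pandas')
```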
@@ -228,7 +227,10 @@ class Explorer:
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
if len(result) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
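With the new empty-result guard, `plot_sql_query` returns `None` when nothing matches, so callers should check before using the image. A hedged sketch (`exp` from the quickstart above):

```python
img = exp.plot_sql_query("WHERE labels LIKE '%person%'", labels=True)
if img is not None:  # None when the query matched no rows
    img.show()
```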
def get_similar(self,
@@ -255,6 +257,8 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
@@ -288,7 +292,10 @@ class Explorer:
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
if len(similar) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
@@ -299,7 +306,7 @@ class Explorer:
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults to 0.01.
vector search. Defaults: None.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
Returns:
@@ -401,6 +408,32 @@ class Explorer:
return img if isinstance(img, list) else [img]
def ask_ai(self, query):
"""
Ask AI a question.
Args:
query (str): Question to ask.
Returns:
Answer from AI.
Example:
```python
exp = Explorer()
exp.create_embeddings_table()
answer = exp.ask_ai('Show images with 1 person and 2 dogs')
```
"""
result = prompt_sql_query(query)
try:
df = self.sql_query(result)
except Exception as e:
LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
LOGGER.error(e)
return None
return df
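Since `ask_ai` swallows invalid AI-generated SQL and returns `None`, a caller-side sketch (assumes an `openai_api_key` already stored in settings; `exp` from the quickstart above):

```python
df = exp.ask_ai('show me images containing more than 2 people')
if df is not None:  # None when the generated query failed to execute
    print(df['im_file'].head())
```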
def visualize(self, result):
"""
Visualize the results of a query.

View file

@@ -1,11 +1,13 @@
import time
from threading import Thread
import pandas as pd
from ultralytics import Explorer
from ultralytics.utils import ROOT
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements('streamlit')
check_requirements('streamlit>=1.29.0')
check_requirements('streamlit-select>=0.2')
import streamlit as st
from streamlit_select import image_select
@@ -35,9 +37,9 @@ def init_explorer_form():
with st.form(key='explorer_init_form'):
col1, col2 = st.columns(2)
with col1:
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
with col2:
model = st.selectbox('Select model', models, key='model')
st.selectbox('Select model', models, key='model')
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
st.form_submit_button('Explore', on_click=_get_explorer)
@@ -47,11 +49,23 @@ def query_form():
with st.form('query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
st.text_input('Query',
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility='collapsed',
key='query')
with col2:
st.form_submit_button('Query', on_click=run_sql_query)
def ai_query_form():
with st.form('ai_query_form'):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
with col2:
st.form_submit_button('Ask AI', on_click=run_ai_query)
def find_similar_imgs(imgs):
exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
@@ -64,12 +78,12 @@ def similarity_form(selected_imgs):
with st.form('similarity_form'):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
limit = st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
with subcol2:
disabled = not len(selected_imgs)
@@ -95,6 +109,7 @@ def similarity_form(selected_imgs):
def run_sql_query():
st.session_state['error'] = None
query = st.session_state.get('query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
@@ -102,9 +117,26 @@ def run_sql_query():
st.session_state['imgs'] = res.to_pydict()['im_file']
def run_ai_query():
if not SETTINGS['openai_api_key']:
st.session_state[
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
return
st.session_state['error'] = None
query = st.session_state.get('ai_query')
if query.rstrip().lstrip():
exp = st.session_state['explorer']
res = exp.ask_ai(query)
if not isinstance(res, pd.DataFrame) or res.empty:
st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
return
st.session_state['imgs'] = res['im_file'].to_list()
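For the GUI path the key must already be stored in settings, matching the error message above. A hedged sketch of seeding it from Python rather than via `yolo settings openai_api_key="..."` (the top-level `settings` export is an assumption of this sketch):

```python
from ultralytics import settings  # assumed top-level export

# Persisted by SettingsManager, so run_ai_query() can read it later
settings.update({'openai_api_key': 'sk-...'})  # placeholder key
```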
def reset_explorer():
st.session_state['explorer'] = None
st.session_state['imgs'] = None
st.session_state['error'] = None
def utralytics_explorer_docs_callback():
@@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None)
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
def layout():
@@ -129,9 +161,12 @@ def layout():
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
exp = st.session_state.get('explorer')
col1, col2 = st.columns([0.75, 0.25], gap='small')
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs = len(imgs)
imgs = []
if st.session_state.get('error'):
st.error(st.session_state['error'])
else:
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
@@ -159,6 +194,7 @@ def layout():
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
imgs_displayed = imgs[start_idx:start_idx + num]
selected_imgs = image_select(

View file

@@ -1,9 +1,14 @@
import getpass
from typing import List
import cv2
import numpy as np
import pandas as pd
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER as logger
from ultralytics.utils import SETTINGS
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.ops import xyxy2xywh
from ultralytics.utils.plotting import plot_images
@@ -47,15 +52,16 @@ def sanitize_batch(batch, dataset_info):
return batch
def plot_similar_images(similar_set, plot_labels=True):
def plot_query_result(similar_set, plot_labels=True):
"""
Plot images from the similar set.
Args:
similar_set (list): Pyarrow table containing the similar data points
similar_set (list): Pyarrow or pandas object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = similar_set.to_pydict()
similar_set = similar_set.to_dict(
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get('im_file', [])
@@ -102,3 +108,61 @@ def plot_similar_images(similar_set, plot_labels=True):
max_subplots=len(images),
save=False,
threaded=False)
def prompt_sql_query(query):
check_requirements('openai>=1.6.1')
from openai import OpenAI
if not SETTINGS['openai_api_key']:
logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
openai_api_key = getpass.getpass('OpenAI API key: ')
SETTINGS.update({'openai_api_key': openai_api_key})
openai = OpenAI(api_key=SETTINGS['openai_api_key'])
messages = [
{
'role':
'system',
'content':
'''
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
Schema:
im_file: string not null
labels: list<item: string> not null
child 0, item: string
cls: list<item: int64> not null
child 0, item: int64
bboxes: list<item: list<item: double>> not null
child 0, item: list<item: double>
child 0, item: double
masks: list<item: list<item: list<item: int64>>> not null
child 0, item: list<item: list<item: int64>>
child 0, item: list<item: int64>
child 0, item: int64
keypoints: list<item: list<item: list<item: double>>> not null
child 0, item: list<item: list<item: double>>
child 0, item: list<item: double>
child 0, item: double
vector: fixed_size_list<item: float>[256] not null
child 0, item: float
Some details about the schema:
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
in each image
- the "cls" column contains the integer values on these classes that map them the labels
Example of a correct query:
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''},
{
'role': 'user',
'content': f'{query}'}, ]
response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
return response.choices[0].message.content
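A hedged sketch of this helper in isolation: it only generates SQL text, and executing it is left to `Explorer.sql_query` (which is exactly how `ask_ai` wires the two together). The import path is assumed from this file's location, and an OpenAI key is prompted for if none is stored:

```python
from ultralytics.data.explorer.utils import prompt_sql_query  # assumed path

sql = prompt_sql_query('images with at least one dog')
print(sql)  # e.g. "SELECT * FROM 'table' WHERE ..."
```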

View file

@@ -246,7 +246,7 @@ class Model(nn.Module):
prompts = args.pop('prompts', None) # for SAM-type models
if not self.predictor:
self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
self.predictor = predictor or self._smart_load('predictor')(overrides=args, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, args)

View file

@@ -41,8 +41,7 @@ class OBBPredictor(DetectionPredictor):
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i]
for i, (pred, orig_img) in enumerate(zip(preds, orig_imgs)):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape, xywh=True)
img_path = self.batch[0][i]
# xywh, r, conf, cls

View file

@@ -61,13 +61,13 @@ class Detect(nn.Module):
dbox = self.decode_bboxes(box)
if self.export and self.format in ('tflite', 'edgetpu'):
# Normalize xywh with image size to mitigate quantization error of TFLite integer models as done in YOLOv5:
# https://github.com/ultralytics/yolov5/blob/0c8de3fca4a702f8ff5c435e67f378d1fce70243/models/tf.py#L307-L309
# See this PR for details: https://github.com/ultralytics/ultralytics/pull/1695
img_h = shape[2] * self.stride[0]
img_w = shape[3] * self.stride[0]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
dbox /= img_size
# Precompute normalization factor to increase numerical stability
# See https://github.com/ultralytics/ultralytics/issues/7371
img_h = shape[2]
img_w = shape[3]
img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1)
norm = self.strides / (self.stride[0] * img_size)
dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1)
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
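The rewrite relies on `dist2bbox` being linear in both its distance and anchor inputs, so pre-scaling by `norm` yields the same normalized boxes while keeping intermediate values small for integer quantization. A toy numerical sketch of that equivalence (shapes and the import path are assumptions; `d` stands in for the DFL-decoded distances):

```python
import torch
from ultralytics.utils.tal import dist2bbox  # assumed import path

d = torch.rand(1, 4, 10)               # ltrb distances per anchor
a = torch.rand(1, 2, 10) * 80          # anchor points in feature-map units
strides = torch.full((1, 1, 10), 8.0)  # single level, stride 8
img_size = torch.tensor([80.0, 80, 80, 80]).reshape(1, 4, 1)  # grid w, h, w, h
norm = strides / (8.0 * img_size)

old = dist2bbox(d, a, xywh=True, dim=1) * strides / (img_size * 8.0)
new = dist2bbox(d * norm, a * norm[:, :2], xywh=True, dim=1)
assert torch.allclose(old, new, atol=1e-6)  # same boxes, tamer numerics
```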

View file

@@ -4,6 +4,7 @@ import math
import cv2
from ultralytics.utils.checks import check_imshow
from ultralytics.utils.plotting import Annotator, colors
@@ -37,6 +38,9 @@ class DistanceCalculation:
self.left_mouse_count = 0
self.selected_boxes = {}
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(self,
names,
pixels_per_meter=10,
@@ -168,7 +172,7 @@ class DistanceCalculation:
self.centroids = []
if self.view_img:
if self.view_img and self.env_check:
self.display_frames()
return im0
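The same guard pattern in isolation, for reference (`check_imshow` returns `False` in headless environments such as Docker or Colab, logging a warning when `warn=True`):

```python
import cv2
import numpy as np
from ultralytics.utils.checks import check_imshow

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy frame for the sketch
if check_imshow(warn=True):  # only display when the environment supports it
    cv2.imshow('frame', frame)
    cv2.waitKey(1)
```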

View file

@@ -28,6 +28,8 @@ class Heatmap:
self.imw = None
self.imh = None
self.im0 = None
self.view_in_counts = True
self.view_out_counts = True
# Heatmap colormap and heatmap np array
self.colormap = None
@@ -67,6 +69,8 @@ class Heatmap:
colormap=cv2.COLORMAP_JET,
heatmap_alpha=0.5,
view_img=False,
view_in_counts=True,
view_out_counts=True,
count_reg_pts=None,
count_txt_thickness=2,
count_txt_color=(0, 0, 0),
@@ -85,6 +89,8 @@ class Heatmap:
imh (int): The height of the frame.
heatmap_alpha (float): alpha value for heatmap display
view_img (bool): Flag indicating frame display
view_in_counts (bool): Flag to control whether to display the incounts on video stream.
view_out_counts (bool): Flag to control whether to display the outcounts on video stream.
count_reg_pts (list): Object counting region points
count_txt_thickness (int): Text thickness for object counting display
count_txt_color (RGB color): count text color value
@@ -99,6 +105,8 @@ class Heatmap:
self.imh = imh
self.heatmap_alpha = heatmap_alpha
self.view_img = view_img
self.view_in_counts = view_in_counts
self.view_out_counts = view_out_counts
self.colormap = colormap
# Region and line selection
@@ -171,9 +179,10 @@ class Heatmap:
if self.count_reg_pts is not None:
# Draw counting region
self.annotator.draw_region(reg_pts=self.count_reg_pts,
color=self.region_color,
thickness=self.region_thickness)
if self.view_in_counts or self.view_out_counts:
self.annotator.draw_region(reg_pts=self.count_reg_pts,
color=self.region_color,
thickness=self.region_thickness)
for box, cls, track_id in zip(self.boxes, self.clss, self.track_ids):
@@ -235,11 +244,22 @@ class Heatmap:
heatmap_normalized = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX)
heatmap_colored = cv2.applyColorMap(heatmap_normalized.astype(np.uint8), self.colormap)
if self.count_reg_pts is not None:
incount_label = 'InCount : ' + f'{self.in_counts}'
outcount_label = 'OutCount : ' + f'{self.out_counts}'
self.annotator.count_labels(in_count=incount_label,
out_count=outcount_label,
incount_label = 'In Count : ' + f'{self.in_counts}'
outcount_label = 'OutCount : ' + f'{self.out_counts}'
# Display counts based on user choice
counts_label = None
if not self.view_in_counts and not self.view_out_counts:
counts_label = None
elif not self.view_in_counts:
counts_label = outcount_label
elif not self.view_out_counts:
counts_label = incount_label
else:
counts_label = incount_label + ' ' + outcount_label
if self.count_reg_pts is not None and counts_label is not None:
self.annotator.count_labels(counts=counts_label,
count_txt_size=self.count_txt_thickness,
txt_color=self.count_txt_color,
color=self.count_color)
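A hedged usage sketch of the new flags (the flag names come from this diff; the remaining `set_args` values and the counting line are illustrative):

```python
import cv2
from ultralytics.solutions import heatmap  # assumed module path

hm = heatmap.Heatmap()
hm.set_args(imw=1280,
            imh=720,
            colormap=cv2.COLORMAP_JET,
            view_img=True,
            view_in_counts=False,  # hide the in-count overlay
            view_out_counts=True,  # keep the out-count overlay
            count_reg_pts=[(20, 400), (1260, 400)])  # illustrative counting line
```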

View file

@@ -856,6 +856,7 @@ class SettingsManager(dict):
'uuid': hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(),
'sync': True,
'api_key': '',
'openai_api_key': '',
'clearml': True, # integrations
'comet': True,
'dvc': True,