ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: uwer <uwe.rosebrock@gmail.com> Co-authored-by: Uwe Rosebrock <ro260@csiro.au> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: AdamP <adamp87hun@gmail.com>
This commit is contained in:
parent
e76754eab0
commit
783033fa6b
19 changed files with 387 additions and 76 deletions
|
|
@ -1,11 +1,13 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ROOT
|
||||
from ultralytics.utils import ROOT, SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('streamlit')
|
||||
check_requirements('streamlit>=1.29.0')
|
||||
check_requirements('streamlit-select>=0.2')
|
||||
import streamlit as st
|
||||
from streamlit_select import image_select
|
||||
|
|
@ -35,9 +37,9 @@ def init_explorer_form():
|
|||
with st.form(key='explorer_init_form'):
|
||||
col1, col2 = st.columns(2)
|
||||
with col1:
|
||||
dataset = st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
|
||||
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
|
||||
with col2:
|
||||
model = st.selectbox('Select model', models, key='model')
|
||||
st.selectbox('Select model', models, key='model')
|
||||
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
|
||||
|
||||
st.form_submit_button('Explore', on_click=_get_explorer)
|
||||
|
|
@ -47,11 +49,23 @@ def query_form():
|
|||
with st.form('query_form'):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
query = st.text_input('Query', '', label_visibility='collapsed', key='query')
|
||||
st.text_input('Query',
|
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
||||
label_visibility='collapsed',
|
||||
key='query')
|
||||
with col2:
|
||||
st.form_submit_button('Query', on_click=run_sql_query)
|
||||
|
||||
|
||||
def ai_query_form():
|
||||
with st.form('ai_query_form'):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
|
||||
with col2:
|
||||
st.form_submit_button('Ask AI', on_click=run_ai_query)
|
||||
|
||||
|
||||
def find_similar_imgs(imgs):
|
||||
exp = st.session_state['explorer']
|
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
|
||||
|
|
@ -64,12 +78,12 @@ def similarity_form(selected_imgs):
|
|||
with st.form('similarity_form'):
|
||||
subcol1, subcol2 = st.columns([1, 1])
|
||||
with subcol1:
|
||||
limit = st.number_input('limit',
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
value=25,
|
||||
label_visibility='collapsed',
|
||||
key='limit')
|
||||
st.number_input('limit',
|
||||
min_value=None,
|
||||
max_value=None,
|
||||
value=25,
|
||||
label_visibility='collapsed',
|
||||
key='limit')
|
||||
|
||||
with subcol2:
|
||||
disabled = not len(selected_imgs)
|
||||
|
|
@ -95,6 +109,7 @@ def similarity_form(selected_imgs):
|
|||
|
||||
|
||||
def run_sql_query():
|
||||
st.session_state['error'] = None
|
||||
query = st.session_state.get('query')
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state['explorer']
|
||||
|
|
@ -102,9 +117,26 @@ def run_sql_query():
|
|||
st.session_state['imgs'] = res.to_pydict()['im_file']
|
||||
|
||||
|
||||
def run_ai_query():
|
||||
if not SETTINGS['openai_api_key']:
|
||||
st.session_state[
|
||||
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
||||
return
|
||||
st.session_state['error'] = None
|
||||
query = st.session_state.get('ai_query')
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state['explorer']
|
||||
res = exp.ask_ai(query)
|
||||
if not isinstance(res, pd.DataFrame) or res.empty:
|
||||
st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
|
||||
return
|
||||
st.session_state['imgs'] = res['im_file'].to_list()
|
||||
|
||||
|
||||
def reset_explorer():
|
||||
st.session_state['explorer'] = None
|
||||
st.session_state['imgs'] = None
|
||||
st.session_state['error'] = None
|
||||
|
||||
|
||||
def utralytics_explorer_docs_callback():
|
||||
|
|
@ -112,10 +144,10 @@ def utralytics_explorer_docs_callback():
|
|||
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
|
||||
width=100)
|
||||
st.markdown(
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href=''>API docs</a> to try examples & learn more</p>",
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
|
||||
unsafe_allow_html=True,
|
||||
help=None)
|
||||
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/')
|
||||
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
|
||||
|
||||
|
||||
def layout():
|
||||
|
|
@ -129,9 +161,12 @@ def layout():
|
|||
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
|
||||
exp = st.session_state.get('explorer')
|
||||
col1, col2 = st.columns([0.75, 0.25], gap='small')
|
||||
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
|
||||
total_imgs = len(imgs)
|
||||
imgs = []
|
||||
if st.session_state.get('error'):
|
||||
st.error(st.session_state['error'])
|
||||
else:
|
||||
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
|
||||
total_imgs, selected_imgs = len(imgs), []
|
||||
with col1:
|
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
|
||||
with subcol1:
|
||||
|
|
@ -159,6 +194,7 @@ def layout():
|
|||
st.experimental_rerun()
|
||||
|
||||
query_form()
|
||||
ai_query_form()
|
||||
if total_imgs:
|
||||
imgs_displayed = imgs[start_idx:start_idx + num]
|
||||
selected_imgs = image_select(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue