ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2024-01-10 03:16:08 +01:00 committed by GitHub
parent e795277391
commit fe27db2f6e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
139 changed files with 6870 additions and 5125 deletions

View file

@ -2,4 +2,4 @@
from .utils import plot_query_result
__all__ = ['plot_query_result']
__all__ = ["plot_query_result"]

View file

@ -22,7 +22,6 @@ from .utils import get_sim_index_schema, get_table_schema, plot_query_result, pr
class ExplorerDataset(YOLODataset):
def __init__(self, *args, data: dict = None, **kwargs) -> None:
super().__init__(*args, data=data, **kwargs)
@ -35,7 +34,7 @@ class ExplorerDataset(YOLODataset):
else: # read image
im = cv2.imread(f) # BGR
if im is None:
raise FileNotFoundError(f'Image Not Found {f}')
raise FileNotFoundError(f"Image Not Found {f}")
h0, w0 = im.shape[:2] # orig hw
return im, (h0, w0), im.shape[:2]
@ -44,7 +43,7 @@ class ExplorerDataset(YOLODataset):
def build_transforms(self, hyp: IterableSimpleNamespace = None):
"""Creates transforms for dataset images without resizing."""
return Format(
bbox_format='xyxy',
bbox_format="xyxy",
normalize=False,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
@ -55,17 +54,16 @@ class ExplorerDataset(YOLODataset):
class Explorer:
def __init__(self,
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
def __init__(
self, data: Union[str, Path] = "coco128.yaml", model: str = "yolov8n.pt", uri: str = "~/ultralytics/explorer"
) -> None:
checks.check_requirements(["lancedb>=0.4.3", "duckdb"])
import lancedb
self.connection = lancedb.connect(uri)
self.table_name = Path(data).name.lower() + '_' + model.lower()
self.sim_idx_base_name = f'{self.table_name}_sim_idx'.lower(
self.table_name = Path(data).name.lower() + "_" + model.lower()
self.sim_idx_base_name = (
f"{self.table_name}_sim_idx".lower()
) # Use this name and append thres and top_k to reuse the table
self.model = YOLO(model)
self.data = data # None
@ -74,7 +72,7 @@ class Explorer:
self.table = None
self.progress = 0
def create_embeddings_table(self, force: bool = False, split: str = 'train') -> None:
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
"""
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
already exists. Pass force=True to overwrite the existing table.
@ -90,20 +88,20 @@ class Explorer:
```
"""
if self.table is not None and not force:
LOGGER.info('Table already exists. Reusing it. Pass force=True to overwrite it.')
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
return
if self.table_name in self.connection.table_names() and not force:
LOGGER.info(f'Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.')
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
self.table = self.connection.open_table(self.table_name)
self.progress = 1
return
if self.data is None:
raise ValueError('Data must be provided to create embeddings table')
raise ValueError("Data must be provided to create embeddings table")
data_info = check_det_dataset(self.data)
if split not in data_info:
raise ValueError(
f'Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}'
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
)
choice_set = data_info[split]
@ -113,13 +111,16 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
table.add(
self._yield_batches(dataset,
data_info,
self.model,
exclude_keys=['img', 'ratio_pad', 'resized_shape', 'ori_shape', 'batch_idx']))
self._yield_batches(
dataset,
data_info,
self.model,
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
)
)
self.table = table
@ -131,12 +132,12 @@ class Explorer:
for k in exclude_keys:
batch.pop(k, None)
batch = sanitize_batch(batch, data_info)
batch['vector'] = model.embed(batch['im_file'], verbose=False)[0].detach().tolist()
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
yield [batch]
def query(self,
imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
limit: int = 25) -> Any: # pyarrow.Table
def query(
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
) -> Any: # pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@ -157,18 +158,18 @@ class Explorer:
```
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
raise ValueError("Table is not created. Please create the table first.")
if isinstance(imgs, str):
imgs = [imgs]
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
return self.table.search(embeds).limit(limit).to_arrow()
def sql_query(self,
query: str,
return_type: str = 'pandas') -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
def sql_query(
self, query: str, return_type: str = "pandas"
) -> Union[DataFrame, Any, None]: # pandas.dataframe or pyarrow.Table
"""
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
@ -187,27 +188,29 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
assert return_type in [
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
raise ValueError("Table is not created. Please create the table first.")
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith('SELECT') and not query.startswith('WHERE'):
if not query.startswith("SELECT") and not query.startswith("WHERE"):
raise ValueError(
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}"
)
if query.startswith('WHERE'):
if query.startswith("WHERE"):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
LOGGER.info(f"Running query: {query}")
rs = duckdb.sql(query)
if return_type == 'pandas':
if return_type == "pandas":
return rs.df()
elif return_type == 'arrow':
elif return_type == "arrow":
return rs.arrow()
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
@ -228,18 +231,20 @@ class Explorer:
result = exp.plot_sql_query(query)
```
"""
result = self.sql_query(query, return_type='arrow')
result = self.sql_query(query, return_type="arrow")
if len(result) == 0:
LOGGER.info('No results found.')
LOGGER.info("No results found.")
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = 'pandas') -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
def get_similar(
self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
return_type: str = "pandas",
) -> Union[DataFrame, Any]: # pandas.dataframe or pyarrow.Table
"""
Query the table for similar images. Accepts a single image or a list of images.
@ -259,21 +264,25 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
assert return_type in [
"pandas",
"arrow",
], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
if return_type == 'pandas':
if return_type == "pandas":
return similar.to_pandas()
elif return_type == 'arrow':
elif return_type == "arrow":
return similar
def plot_similar(self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True) -> Image.Image:
def plot_similar(
self,
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
idx: Union[int, List[int]] = None,
limit: int = 25,
labels: bool = True,
) -> Image.Image:
"""
Plot the similar images. Accepts images or indexes.
@ -293,9 +302,9 @@ class Explorer:
similar = exp.plot_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
similar = self.get_similar(img, idx, limit, return_type="arrow")
if len(similar) == 0:
LOGGER.info('No results found.')
LOGGER.info("No results found.")
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
@ -323,34 +332,37 @@ class Explorer:
```
"""
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
sim_idx_table_name = f'{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}'.lower()
raise ValueError("Table is not created. Please create the table first.")
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
if sim_idx_table_name in self.connection.table_names() and not force:
LOGGER.info('Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.')
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
return self.connection.open_table(sim_idx_table_name).to_pandas()
if top_k and not (1.0 >= top_k >= 0.0):
raise ValueError(f'top_k must be between 0.0 and 1.0. Got {top_k}')
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
if max_dist < 0.0:
raise ValueError(f'max_dist must be greater than 0. Got {max_dist}')
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
top_k = max(top_k, 1)
features = self.table.to_lance().to_table(columns=['vector', 'im_file']).to_pydict()
im_files = features['im_file']
embeddings = features['vector']
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
im_files = features["im_file"]
embeddings = features["vector"]
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode='overwrite')
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
def _yield_sim_idx():
"""Generates a dataframe with similarity indices and distances for images."""
for i in tqdm(range(len(embeddings))):
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f'_distance <= {max_dist}')
yield [{
'idx': i,
'im_file': im_files[i],
'count': len(sim_idx),
'sim_im_files': sim_idx['im_file'].tolist()}]
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
yield [
{
"idx": i,
"im_file": im_files[i],
"count": len(sim_idx),
"sim_im_files": sim_idx["im_file"].tolist(),
}
]
sim_table.add(_yield_sim_idx())
self.sim_index = sim_table
@ -381,7 +393,7 @@ class Explorer:
```
"""
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
sim_count = sim_idx['count'].tolist()
sim_count = sim_idx["count"].tolist()
sim_count = np.array(sim_count)
indices = np.arange(len(sim_count))
@ -390,25 +402,26 @@ class Explorer:
plt.bar(indices, sim_count)
# Customize the plot (optional)
plt.xlabel('data idx')
plt.ylabel('Count')
plt.title('Similarity Count')
plt.xlabel("data idx")
plt.ylabel("Count")
plt.title("Similarity Count")
buffer = BytesIO()
plt.savefig(buffer, format='png')
plt.savefig(buffer, format="png")
buffer.seek(0)
# Use Pillow to open the image from the buffer
return Image.fromarray(np.array(Image.open(buffer)))
def _check_imgs_or_idxs(self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None],
idx: Union[None, int, List[int]]) -> List[np.ndarray]:
def _check_imgs_or_idxs(
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
) -> List[np.ndarray]:
if img is None and idx is None:
raise ValueError('Either img or idx must be provided.')
raise ValueError("Either img or idx must be provided.")
if img is not None and idx is not None:
raise ValueError('Only one of img or idx must be provided.')
raise ValueError("Only one of img or idx must be provided.")
if idx is not None:
idx = idx if isinstance(idx, list) else [idx]
img = self.table.to_lance().take(idx, columns=['im_file']).to_pydict()['im_file']
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
return img if isinstance(img, list) else [img]
@ -433,7 +446,7 @@ class Explorer:
try:
df = self.sql_query(result)
except Exception as e:
LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
LOGGER.error(e)
return None
return df

View file

@ -9,100 +9,114 @@ from ultralytics import Explorer
from ultralytics.utils import ROOT, SETTINGS
from ultralytics.utils.checks import check_requirements
check_requirements(('streamlit>=1.29.0', 'streamlit-select>=0.2'))
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.2"))
import streamlit as st
from streamlit_select import image_select
def _get_explorer():
"""Initializes and returns an instance of the Explorer class."""
exp = Explorer(data=st.session_state.get('dataset'), model=st.session_state.get('model'))
thread = Thread(target=exp.create_embeddings_table,
kwargs={'force': st.session_state.get('force_recreate_embeddings')})
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
thread = Thread(
target=exp.create_embeddings_table, kwargs={"force": st.session_state.get("force_recreate_embeddings")}
)
thread.start()
progress_bar = st.progress(0, text='Creating embeddings table...')
progress_bar = st.progress(0, text="Creating embeddings table...")
while exp.progress < 1:
time.sleep(0.1)
progress_bar.progress(exp.progress, text=f'Progress: {exp.progress * 100}%')
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
thread.join()
st.session_state['explorer'] = exp
st.session_state["explorer"] = exp
progress_bar.empty()
def init_explorer_form():
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
datasets = ROOT / 'cfg' / 'datasets'
ds = [d.name for d in datasets.glob('*.yaml')]
datasets = ROOT / "cfg" / "datasets"
ds = [d.name for d in datasets.glob("*.yaml")]
models = [
'yolov8n.pt', 'yolov8s.pt', 'yolov8m.pt', 'yolov8l.pt', 'yolov8x.pt', 'yolov8n-seg.pt', 'yolov8s-seg.pt',
'yolov8m-seg.pt', 'yolov8l-seg.pt', 'yolov8x-seg.pt', 'yolov8n-pose.pt', 'yolov8s-pose.pt', 'yolov8m-pose.pt',
'yolov8l-pose.pt', 'yolov8x-pose.pt']
with st.form(key='explorer_init_form'):
"yolov8n.pt",
"yolov8s.pt",
"yolov8m.pt",
"yolov8l.pt",
"yolov8x.pt",
"yolov8n-seg.pt",
"yolov8s-seg.pt",
"yolov8m-seg.pt",
"yolov8l-seg.pt",
"yolov8x-seg.pt",
"yolov8n-pose.pt",
"yolov8s-pose.pt",
"yolov8m-pose.pt",
"yolov8l-pose.pt",
"yolov8x-pose.pt",
]
with st.form(key="explorer_init_form"):
col1, col2 = st.columns(2)
with col1:
st.selectbox('Select dataset', ds, key='dataset', index=ds.index('coco128.yaml'))
st.selectbox("Select dataset", ds, key="dataset", index=ds.index("coco128.yaml"))
with col2:
st.selectbox('Select model', models, key='model')
st.checkbox('Force recreate embeddings', key='force_recreate_embeddings')
st.selectbox("Select model", models, key="model")
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
st.form_submit_button('Explore', on_click=_get_explorer)
st.form_submit_button("Explore", on_click=_get_explorer)
def query_form():
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
with st.form('query_form'):
with st.form("query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input('Query',
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility='collapsed',
key='query')
st.text_input(
"Query",
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
label_visibility="collapsed",
key="query",
)
with col2:
st.form_submit_button('Query', on_click=run_sql_query)
st.form_submit_button("Query", on_click=run_sql_query)
def ai_query_form():
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
with st.form('ai_query_form'):
with st.form("ai_query_form"):
col1, col2 = st.columns([0.8, 0.2])
with col1:
st.text_input('Query', 'Show images with 1 person and 1 dog', label_visibility='collapsed', key='ai_query')
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
with col2:
st.form_submit_button('Ask AI', on_click=run_ai_query)
st.form_submit_button("Ask AI", on_click=run_ai_query)
def find_similar_imgs(imgs):
"""Initializes a Streamlit form for AI-based image querying with custom input."""
exp = st.session_state['explorer']
similar = exp.get_similar(img=imgs, limit=st.session_state.get('limit'), return_type='arrow')
paths = similar.to_pydict()['im_file']
st.session_state['imgs'] = paths
exp = st.session_state["explorer"]
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
paths = similar.to_pydict()["im_file"]
st.session_state["imgs"] = paths
def similarity_form(selected_imgs):
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
st.write('Similarity Search')
with st.form('similarity_form'):
st.write("Similarity Search")
with st.form("similarity_form"):
subcol1, subcol2 = st.columns([1, 1])
with subcol1:
st.number_input('limit',
min_value=None,
max_value=None,
value=25,
label_visibility='collapsed',
key='limit')
st.number_input(
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
)
with subcol2:
disabled = not len(selected_imgs)
st.write('Selected: ', len(selected_imgs))
st.write("Selected: ", len(selected_imgs))
st.form_submit_button(
'Search',
"Search",
disabled=disabled,
on_click=find_similar_imgs,
args=(selected_imgs, ),
args=(selected_imgs,),
)
if disabled:
st.error('Select at least one image to search.')
st.error("Select at least one image to search.")
# def persist_reset_form():
@ -117,100 +131,108 @@ def similarity_form(selected_imgs):
def run_sql_query():
"""Executes an SQL query and returns the results."""
st.session_state['error'] = None
query = st.session_state.get('query')
st.session_state["error"] = None
query = st.session_state.get("query")
if query.rstrip().lstrip():
exp = st.session_state['explorer']
res = exp.sql_query(query, return_type='arrow')
st.session_state['imgs'] = res.to_pydict()['im_file']
exp = st.session_state["explorer"]
res = exp.sql_query(query, return_type="arrow")
st.session_state["imgs"] = res.to_pydict()["im_file"]
def run_ai_query():
"""Execute SQL query and update session state with query results."""
if not SETTINGS['openai_api_key']:
if not SETTINGS["openai_api_key"]:
st.session_state[
'error'] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
"error"
] = 'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
return
st.session_state['error'] = None
query = st.session_state.get('ai_query')
st.session_state["error"] = None
query = st.session_state.get("ai_query")
if query.rstrip().lstrip():
exp = st.session_state['explorer']
exp = st.session_state["explorer"]
res = exp.ask_ai(query)
if not isinstance(res, pd.DataFrame) or res.empty:
st.session_state['error'] = 'No results found using AI generated query. Try another query or rerun it.'
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
return
st.session_state['imgs'] = res['im_file'].to_list()
st.session_state["imgs"] = res["im_file"].to_list()
def reset_explorer():
"""Resets the explorer to its initial state by clearing session variables."""
st.session_state['explorer'] = None
st.session_state['imgs'] = None
st.session_state['error'] = None
st.session_state["explorer"] = None
st.session_state["imgs"] = None
st.session_state["error"] = None
def utralytics_explorer_docs_callback():
"""Resets the explorer to its initial state by clearing session variables."""
with st.container(border=True):
st.image('https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg',
width=100)
st.image(
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
width=100,
)
st.markdown(
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
unsafe_allow_html=True,
help=None)
st.link_button('Ultrlaytics Explorer API', 'https://docs.ultralytics.com/datasets/explorer/')
help=None,
)
st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
def layout():
"""Resets explorer session variables and provides documentation with a link to API docs."""
st.set_page_config(layout='wide', initial_sidebar_state='collapsed')
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
if st.session_state.get('explorer') is None:
if st.session_state.get("explorer") is None:
init_explorer_form()
return
st.button(':arrow_backward: Select Dataset', on_click=reset_explorer)
exp = st.session_state.get('explorer')
col1, col2 = st.columns([0.75, 0.25], gap='small')
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
exp = st.session_state.get("explorer")
col1, col2 = st.columns([0.75, 0.25], gap="small")
imgs = []
if st.session_state.get('error'):
st.error(st.session_state['error'])
if st.session_state.get("error"):
st.error(st.session_state["error"])
else:
imgs = st.session_state.get('imgs') or exp.table.to_lance().to_table(columns=['im_file']).to_pydict()['im_file']
imgs = st.session_state.get("imgs") or exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
total_imgs, selected_imgs = len(imgs), []
with col1:
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
with subcol1:
st.write('Max Images Displayed:')
st.write("Max Images Displayed:")
with subcol2:
num = st.number_input('Max Images Displayed',
min_value=0,
max_value=total_imgs,
value=min(500, total_imgs),
key='num_imgs_displayed',
label_visibility='collapsed')
num = st.number_input(
"Max Images Displayed",
min_value=0,
max_value=total_imgs,
value=min(500, total_imgs),
key="num_imgs_displayed",
label_visibility="collapsed",
)
with subcol3:
st.write('Start Index:')
st.write("Start Index:")
with subcol4:
start_idx = st.number_input('Start Index',
min_value=0,
max_value=total_imgs,
value=0,
key='start_index',
label_visibility='collapsed')
start_idx = st.number_input(
"Start Index",
min_value=0,
max_value=total_imgs,
value=0,
key="start_index",
label_visibility="collapsed",
)
with subcol5:
reset = st.button('Reset', use_container_width=False, key='reset')
reset = st.button("Reset", use_container_width=False, key="reset")
if reset:
st.session_state['imgs'] = None
st.session_state["imgs"] = None
st.experimental_rerun()
query_form()
ai_query_form()
if total_imgs:
imgs_displayed = imgs[start_idx:start_idx + num]
imgs_displayed = imgs[start_idx : start_idx + num]
selected_imgs = image_select(
f'Total samples: {total_imgs}',
f"Total samples: {total_imgs}",
images=imgs_displayed,
use_container_width=False,
# indices=[i for i in range(num)] if select_all else None,
@ -222,5 +244,5 @@ def layout():
utralytics_explorer_docs_callback()
if __name__ == '__main__':
if __name__ == "__main__":
layout()

View file

@ -46,14 +46,13 @@ def get_sim_index_schema():
def sanitize_batch(batch, dataset_info):
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
batch['cls'] = batch['cls'].flatten().int().tolist()
box_cls_pair = sorted(zip(batch['bboxes'].tolist(), batch['cls']), key=lambda x: x[1])
batch['bboxes'] = [box for box, _ in box_cls_pair]
batch['cls'] = [cls for _, cls in box_cls_pair]
batch['labels'] = [dataset_info['names'][i] for i in batch['cls']]
batch['masks'] = batch['masks'].tolist() if 'masks' in batch else [[[]]]
batch['keypoints'] = batch['keypoints'].tolist() if 'keypoints' in batch else [[[]]]
batch["cls"] = batch["cls"].flatten().int().tolist()
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
batch["bboxes"] = [box for box, _ in box_cls_pair]
batch["cls"] = [cls for _, cls in box_cls_pair]
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
return batch
@ -65,15 +64,16 @@ def plot_query_result(similar_set, plot_labels=True):
similar_set (list): Pyarrow or pandas object containing the similar data points
plot_labels (bool): Whether to plot labels or not
"""
similar_set = similar_set.to_dict(
orient='list') if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
similar_set = (
similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
)
empty_masks = [[[]]]
empty_boxes = [[]]
images = similar_set.get('im_file', [])
bboxes = similar_set.get('bboxes', []) if similar_set.get('bboxes') is not empty_boxes else []
masks = similar_set.get('masks') if similar_set.get('masks')[0] != empty_masks else []
kpts = similar_set.get('keypoints') if similar_set.get('keypoints')[0] != empty_masks else []
cls = similar_set.get('cls', [])
images = similar_set.get("im_file", [])
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
cls = similar_set.get("cls", [])
plot_size = 640
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
@ -104,34 +104,26 @@ def plot_query_result(similar_set, plot_labels=True):
batch_idx = np.concatenate(batch_idx, axis=0)
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
return plot_images(imgs,
batch_idx,
cls,
bboxes=boxes,
masks=masks,
kpts=kpts,
max_subplots=len(images),
save=False,
threaded=False)
return plot_images(
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
)
def prompt_sql_query(query):
"""Plots images with optional labels from a similar data set."""
check_requirements('openai>=1.6.1')
check_requirements("openai>=1.6.1")
from openai import OpenAI
if not SETTINGS['openai_api_key']:
logger.warning('OpenAI API key not found in settings. Please enter your API key below.')
openai_api_key = getpass.getpass('OpenAI API key: ')
SETTINGS.update({'openai_api_key': openai_api_key})
openai = OpenAI(api_key=SETTINGS['openai_api_key'])
if not SETTINGS["openai_api_key"]:
logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
openai_api_key = getpass.getpass("OpenAI API key: ")
SETTINGS.update({"openai_api_key": openai_api_key})
openai = OpenAI(api_key=SETTINGS["openai_api_key"])
messages = [
{
'role':
'system',
'content':
'''
"role": "system",
"content": """
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
the following schema and a user request. You only need to output the format with fixed selection
statement that selects everything from "'table'", like `SELECT * from 'table'`
@ -165,10 +157,10 @@ def prompt_sql_query(query):
request - Get all data points that contain 2 or more people and at least one dog
correct query-
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
'''},
{
'role': 'user',
'content': f'{query}'}, ]
""",
},
{"role": "user", "content": f"{query}"},
]
response = openai.chat.completions.create(model='gpt-3.5-turbo', messages=messages)
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
return response.choices[0].message.content