Remove explorer Integration (#16842)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1b52e5e693
commit
54d8801dfb
20 changed files with 36 additions and 1118 deletions
|
|
@ -6,6 +6,10 @@ keywords: Ultralytics, Explorer API, dataset exploration, SQL queries, similarit
|
|||
|
||||
# Ultralytics Explorer API
|
||||
|
||||
!!! warning "Community Note ⚠️"
|
||||
|
||||
As of **`ultralytics>=8.3.10`**, Ultralytics explorer support has been deprecated. But don't worry! You can now access similar and even enhanced functionality through [Ultralytics HUB](https://hub.ultralytics.com/), our intuitive no-code platform designed to streamline your workflow. With Ultralytics HUB, you can continue exploring, visualizing, and managing your data effortlessly, all without writing a single line of code. Make sure to check it out and take advantage of its powerful features!🚀
|
||||
|
||||
## Introduction
|
||||
|
||||
<a href="https://colab.research.google.com/github/ultralytics/ultralytics/blob/main/docs/en/datasets/explorer/explorer.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ keywords: Ultralytics Explorer GUI, semantic search, vector similarity, SQL quer
|
|||
|
||||
# Explorer GUI
|
||||
|
||||
!!! warning "Community Note ⚠️"
|
||||
|
||||
As of **`ultralytics>=8.3.10`**, Ultralytics explorer support has been deprecated. But don't worry! You can now access similar and even enhanced functionality through [Ultralytics HUB](https://hub.ultralytics.com/), our intuitive no-code platform designed to streamline your workflow. With Ultralytics HUB, you can continue exploring, visualizing, and managing your data effortlessly, all without writing a single line of code. Make sure to check it out and take advantage of its powerful features!🚀
|
||||
|
||||
Explorer GUI is like a playground build using [Ultralytics Explorer API](api.md). It allows you to run semantic/vector similarity search, SQL queries and even search using natural language using our ask AI feature powered by LLMs.
|
||||
|
||||
<p>
|
||||
|
|
|
|||
|
|
@ -30,6 +30,18 @@
|
|||
"</div>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Ultralytics Explorer support deprecated ⚠️\n",
|
||||
"\n",
|
||||
"As of **`ultralytics>=8.3.10`**, Ultralytics explorer support has been deprecated. But don’t worry! You can now access similar and even enhanced functionality through [Ultralytics HUB](https://hub.ultralytics.com/), our intuitive no-code platform designed to streamline your workflow. With Ultralytics HUB, you can continue exploring, visualizing, and managing your data effortlessly, all without writing a single line of code. Make sure to check it out and take advantage of its powerful features!🚀"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "RHe1PX5c7uK2"
|
||||
},
|
||||
"id": "RHe1PX5c7uK2"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2454d9ba-9db4-4b37-98e8-201ba285c92f",
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ keywords: Ultralytics Explorer, CV datasets, semantic search, SQL queries, vecto
|
|||
|
||||
# Ultralytics Explorer
|
||||
|
||||
!!! warning "Community Note ⚠️"
|
||||
|
||||
As of **`ultralytics>=8.3.10`**, Ultralytics explorer support has been deprecated. But don't worry! You can now access similar and even enhanced functionality through [Ultralytics HUB](https://hub.ultralytics.com/), our intuitive no-code platform designed to streamline your workflow. With Ultralytics HUB, you can continue exploring, visualizing, and managing your data effortlessly, all without writing a single line of code. Make sure to check it out and take advantage of its powerful features!🚀
|
||||
|
||||
<p>
|
||||
<img width="1709" alt="Ultralytics Explorer Screenshot 1" src="https://github.com/ultralytics/docs/releases/download/0/explorer-dashboard-screenshot-1.avif">
|
||||
</p>
|
||||
|
|
|
|||
|
|
@ -120,6 +120,10 @@ Common tools for visualizations include:
|
|||
|
||||
### Using Ultralytics Explorer for EDA
|
||||
|
||||
!!! warning "Community Note ⚠️"
|
||||
|
||||
As of **`ultralytics>=8.3.10`**, Ultralytics explorer support has been deprecated. But don't worry! You can now access similar and even enhanced functionality through [Ultralytics HUB](https://hub.ultralytics.com/), our intuitive no-code platform designed to streamline your workflow. With Ultralytics HUB, you can continue exploring, visualizing, and managing your data effortlessly, all without writing a single line of code. Make sure to check it out and take advantage of its powerful features!🚀
|
||||
|
||||
For a more advanced approach to EDA, you can use the Ultralytics Explorer tool. It offers robust capabilities for exploring computer vision datasets. By supporting semantic search, SQL queries, and vector similarity search, the tool makes it easy to analyze and understand your data. With Ultralytics Explorer, you can create [embeddings](https://www.ultralytics.com/glossary/embeddings) for your dataset to find similar images, run SQL queries for detailed analysis, and perform semantic searches, all through a user-friendly graphical interface.
|
||||
|
||||
<p align="center">
|
||||
|
|
|
|||
|
|
@ -47,10 +47,6 @@ keywords: Ultralytics, YOLO, configuration, cfg2dict, get_cfg, check_cfg, save_d
|
|||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.cfg.handle_explorer
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.cfg.handle_streamlit_inference
|
||||
|
||||
<br><br><hr><br>
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
---
|
||||
comments: true
|
||||
description: Explore the Ultralytics data explorer functions including YOLO dataset handling, image querying, embedding generation, and similarity indexing.
|
||||
keywords: Ultralytics, YOLO, data explorer, image querying, embeddings, similarity index, python, machine learning
|
||||
---
|
||||
|
||||
# Reference for `ultralytics/data/explorer/explorer.py`
|
||||
|
||||
!!! note
|
||||
|
||||
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/explorer.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/explorer.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/data/explorer/explorer.py) 🛠️. Thank you 🙏!
|
||||
|
||||
<br>
|
||||
|
||||
## ::: ultralytics.data.explorer.explorer.ExplorerDataset
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.explorer.Explorer
|
||||
|
||||
<br><br>
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
---
|
||||
comments: true
|
||||
description: Explore the functionalities of Ultralytics Explorer with our comprehensive GUI dash documentation.
|
||||
keywords: Ultralytics, Explorer, GUI, dash, documentation, data explorer, AI query, SQL query, image similarity
|
||||
---
|
||||
|
||||
# Reference for `ultralytics/data/explorer/gui/dash.py`
|
||||
|
||||
!!! note
|
||||
|
||||
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/gui/dash.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/gui/dash.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/data/explorer/gui/dash.py) 🛠️. Thank you 🙏!
|
||||
|
||||
<br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash._get_explorer
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.init_explorer_form
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.query_form
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.ai_query_form
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.find_similar_imgs
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.similarity_form
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.run_sql_query
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.run_ai_query
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.reset_explorer
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.utralytics_explorer_docs_callback
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.gui.dash.layout
|
||||
|
||||
<br><br>
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
---
|
||||
comments: true
|
||||
description: Explore various utility functions in ultralytics.data.explorer.utils including schema definitions, batch sanitization, and query results plotting.
|
||||
keywords: Ultralytics, data explorer, utils, schema, sanitize batch, plot query results, SQL query, machine learning
|
||||
---
|
||||
|
||||
# Reference for `ultralytics/data/explorer/utils.py`
|
||||
|
||||
!!! note
|
||||
|
||||
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/utils.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/data/explorer/utils.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/data/explorer/utils.py) 🛠️. Thank you 🙏!
|
||||
|
||||
<br>
|
||||
|
||||
## ::: ultralytics.data.explorer.utils.get_table_schema
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.utils.get_sim_index_schema
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.utils.sanitize_batch
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.utils.plot_query_result
|
||||
|
||||
<br><br><hr><br>
|
||||
|
||||
## ::: ultralytics.data.explorer.utils.prompt_sql_query
|
||||
|
||||
<br><br>
|
||||
|
|
@ -256,50 +256,6 @@ Benchmark mode is used to profile the speed and accuracy of various export forma
|
|||
|
||||
[Benchmark Examples](../modes/benchmark.md){ .md-button }
|
||||
|
||||
## Explorer
|
||||
|
||||
Explorer API can be used to explore datasets with advanced semantic, vector-similarity and SQL search among other features. It also enabled searching for images based on their content using natural language by utilizing the power of LLMs. The Explorer API allows you to write your own dataset exploration notebooks or scripts to get insights into your datasets.
|
||||
|
||||
!!! example "Semantic Search Using Explorer"
|
||||
|
||||
=== "Using Images"
|
||||
|
||||
```python
|
||||
from ultralytics import Explorer
|
||||
|
||||
# create an Explorer object
|
||||
exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
|
||||
exp.create_embeddings_table()
|
||||
|
||||
similar = exp.get_similar(img="https://ultralytics.com/images/bus.jpg", limit=10)
|
||||
print(similar.head())
|
||||
|
||||
# Search using multiple indices
|
||||
similar = exp.get_similar(
|
||||
img=["https://ultralytics.com/images/bus.jpg", "https://ultralytics.com/images/bus.jpg"], limit=10
|
||||
)
|
||||
print(similar.head())
|
||||
```
|
||||
|
||||
=== "Using Dataset Indices"
|
||||
|
||||
```python
|
||||
from ultralytics import Explorer
|
||||
|
||||
# create an Explorer object
|
||||
exp = Explorer(data="coco8.yaml", model="yolo11n.pt")
|
||||
exp.create_embeddings_table()
|
||||
|
||||
similar = exp.get_similar(idx=1, limit=10)
|
||||
print(similar.head())
|
||||
|
||||
# Search using multiple indices
|
||||
similar = exp.get_similar(idx=[1, 10], limit=10)
|
||||
print(similar.head())
|
||||
```
|
||||
|
||||
[Explorer](../datasets/explorer/index.md){ .md-button }
|
||||
|
||||
## Using Trainers
|
||||
|
||||
`YOLO` model class is a high-level wrapper on the Trainer classes. Each YOLO task has its own trainer that inherits from `BaseTrainer`.
|
||||
|
|
|
|||
|
|
@ -25,10 +25,6 @@ The `ultralytics` package comes with a myriad of utilities that can support, enh
|
|||
|
||||
## Data
|
||||
|
||||
### YOLO Data Explorer
|
||||
|
||||
[YOLO Explorer](../datasets/explorer/index.md) was added in the `8.1.0` anniversary update and is a powerful tool you can use to better understand your dataset. One of the key functions that YOLO Explorer provides, is the ability to use text queries to find object instances in your dataset.
|
||||
|
||||
### Auto Labeling / Annotations
|
||||
|
||||
Dataset annotation is a very resource intensive and time-consuming process. If you have a YOLO [object detection](https://www.ultralytics.com/glossary/object-detection) model trained on a reasonable amount of data, you can use it and [SAM](../models/sam.md) to auto-annotate additional data (segmentation format).
|
||||
|
|
|
|||
15
mkdocs.yml
15
mkdocs.yml
|
|
@ -162,8 +162,6 @@ nav:
|
|||
- solutions/index.md
|
||||
- Guides:
|
||||
- guides/index.md
|
||||
- Explorer:
|
||||
- datasets/explorer/index.md
|
||||
- Live Inference 🚀 NEW: guides/streamlit-live-inference.md # for promotion of new pages
|
||||
- Languages:
|
||||
- 🇬🇧  English: https://ultralytics.com/docs/
|
||||
|
|
@ -261,11 +259,6 @@ nav:
|
|||
- YOLO-World (Real-Time Open-Vocabulary Object Detection): models/yolo-world.md
|
||||
- Datasets:
|
||||
- datasets/index.md
|
||||
- Explorer:
|
||||
- datasets/explorer/index.md
|
||||
- Explorer API: datasets/explorer/api.md
|
||||
- Explorer Dashboard: datasets/explorer/dashboard.md
|
||||
- VOC Exploration Example: datasets/explorer/explorer.ipynb
|
||||
- Detection:
|
||||
- datasets/detect/index.md
|
||||
- Argoverse: datasets/detect/argoverse.md
|
||||
|
|
@ -476,11 +469,6 @@ nav:
|
|||
- build: reference/data/build.md
|
||||
- converter: reference/data/converter.md
|
||||
- dataset: reference/data/dataset.md
|
||||
- explorer:
|
||||
- explorer: reference/data/explorer/explorer.md
|
||||
- gui:
|
||||
- dash: reference/data/explorer/gui/dash.md
|
||||
- utils: reference/data/explorer/utils.md
|
||||
- loaders: reference/data/loaders.md
|
||||
- split_dota: reference/data/split_dota.md
|
||||
- utils: reference/data/utils.md
|
||||
|
|
@ -761,3 +749,6 @@ plugins:
|
|||
yolov5/environments/yolov5_amazon_web_services_quickstart_tutorial.md: yolov5/environments/aws_quickstart_tutorial.md
|
||||
yolov5/environments/yolov5_google_cloud_platform_quickstart_tutorial.md: yolov5/environments/google_cloud_quickstart_tutorial.md
|
||||
yolov5/environments/yolov5_docker_image_quickstart_tutorial.md: yolov5/environments/docker_image_quickstart_tutorial.md
|
||||
reference/data/explorer/explorer.md: datasets/explorer/index.md
|
||||
reference/data/explorer/gui/dash.md: datasets/explorer/index.md
|
||||
reference/data/explorer/utils.md: datasets/explorer/index.md
|
||||
|
|
|
|||
|
|
@ -107,10 +107,9 @@ export = [
|
|||
"numpy==1.23.5; platform_machine == 'aarch64'", # fix error: `np.bool` was a deprecated alias for the builtin `bool` when using TensorRT models on NVIDIA Jetson
|
||||
"h5py!=3.11.0; platform_machine == 'aarch64'", # fix h5py build issues due to missing aarch64 wheels in 3.11 release
|
||||
]
|
||||
explorer = [
|
||||
"lancedb", # vector search
|
||||
"duckdb<=0.9.2", # SQL queries, duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
|
||||
"streamlit", # visualizing with GUI
|
||||
solutions = [
|
||||
"shapely>=2.0.0", # shapely for point and polygon data matching
|
||||
"streamlit", # for live inference on web browser i.e `yolo streamlit-predict`
|
||||
]
|
||||
logging = [
|
||||
"comet", # https://docs.ultralytics.com/integrations/comet/
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import os
|
|||
if not os.environ.get("OMP_NUM_THREADS"):
|
||||
os.environ["OMP_NUM_THREADS"] = "1" # default for reduced CPU utilization during training
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld
|
||||
from ultralytics.utils import ASSETS, SETTINGS
|
||||
from ultralytics.utils.checks import check_yolo as checks
|
||||
|
|
@ -27,5 +26,4 @@ __all__ = (
|
|||
"checks",
|
||||
"download",
|
||||
"settings",
|
||||
"Explorer",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -80,13 +80,10 @@ CLI_HELP_MSG = f"""
|
|||
4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
|
||||
yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
|
||||
|
||||
5. Explore your datasets using semantic search and SQL with a simple GUI powered by Ultralytics Explorer API
|
||||
yolo explorer data=data.yaml model=yolo11n.pt
|
||||
|
||||
6. Streamlit real-time webcam inference GUI
|
||||
5. Streamlit real-time webcam inference GUI
|
||||
yolo streamlit-predict
|
||||
|
||||
7. Run special commands:
|
||||
6. Run special commands:
|
||||
yolo help
|
||||
yolo checks
|
||||
yolo version
|
||||
|
|
@ -546,35 +543,6 @@ def handle_yolo_settings(args: List[str]) -> None:
|
|||
LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
|
||||
|
||||
|
||||
def handle_explorer(args: List[str]):
|
||||
"""
|
||||
Launches a graphical user interface that provides tools for interacting with and analyzing datasets using the
|
||||
Ultralytics Explorer API. It checks for the required 'streamlit' package and informs the user that the Explorer
|
||||
dashboard is loading.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of optional command line arguments.
|
||||
|
||||
Examples:
|
||||
```bash
|
||||
yolo explorer data=data.yaml model=yolo11n.pt
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Requires 'streamlit' package version 1.29.0 or higher.
|
||||
- The function does not take any arguments or return any values.
|
||||
- It is typically called from the command line interface using the 'yolo explorer' command.
|
||||
"""
|
||||
checks.check_requirements("streamlit>=1.29.0")
|
||||
LOGGER.info("💡 Loading Explorer dashboard...")
|
||||
cmd = ["streamlit", "run", ROOT / "data/explorer/gui/dash.py", "--server.maxMessageSize", "2048"]
|
||||
new = dict(parse_key_value_pair(a) for a in args)
|
||||
check_dict_alignment(base={k: DEFAULT_CFG_DICT[k] for k in ["model", "data"]}, custom=new)
|
||||
for k, v in new.items():
|
||||
cmd += [k, v]
|
||||
subprocess.run(cmd)
|
||||
|
||||
|
||||
def handle_streamlit_inference():
|
||||
"""
|
||||
Open the Ultralytics Live Inference Streamlit app for real-time object detection.
|
||||
|
|
@ -715,7 +683,6 @@ def entrypoint(debug=""):
|
|||
"login": lambda: handle_yolo_hub(args),
|
||||
"logout": lambda: handle_yolo_hub(args),
|
||||
"copy-cfg": copy_default_cfg,
|
||||
"explorer": lambda: handle_explorer(args[1:]),
|
||||
"streamlit-predict": lambda: handle_streamlit_inference(),
|
||||
}
|
||||
full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special}
|
||||
|
|
|
|||
|
|
@ -1,5 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .utils import plot_query_result
|
||||
|
||||
__all__ = ["plot_query_result"]
|
||||
|
|
@ -1,460 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from ultralytics.data.augment import Format
|
||||
from ultralytics.data.dataset import YOLODataset
|
||||
from ultralytics.data.utils import check_det_dataset
|
||||
from ultralytics.models.yolo.model import YOLO
|
||||
from ultralytics.utils import LOGGER, USER_CONFIG_DIR, IterableSimpleNamespace, checks
|
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
|
||||
"""Extends YOLODataset for advanced data exploration and manipulation in model training workflows."""
|
||||
|
||||
def __init__(self, *args, data: dict = None, **kwargs) -> None:
|
||||
"""Initializes the ExplorerDataset with the provided data arguments, extending the YOLODataset class."""
|
||||
super().__init__(*args, data=data, **kwargs)
|
||||
|
||||
def load_image(self, i: int) -> Union[Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]], Tuple[None, None, None]]:
|
||||
"""Loads 1 image from dataset index 'i' without any resize ops."""
|
||||
im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
|
||||
if im is None: # not cached in RAM
|
||||
if fn.exists(): # load npy
|
||||
im = np.load(fn)
|
||||
else: # read image
|
||||
im = cv2.imread(f) # BGR
|
||||
if im is None:
|
||||
raise FileNotFoundError(f"Image Not Found {f}")
|
||||
h0, w0 = im.shape[:2] # orig hw
|
||||
return im, (h0, w0), im.shape[:2]
|
||||
|
||||
return self.ims[i], self.im_hw0[i], self.im_hw[i]
|
||||
|
||||
def build_transforms(self, hyp: IterableSimpleNamespace = None):
|
||||
"""Creates transforms for dataset images without resizing."""
|
||||
return Format(
|
||||
bbox_format="xyxy",
|
||||
normalize=False,
|
||||
return_mask=self.use_segments,
|
||||
return_keypoint=self.use_keypoints,
|
||||
batch_idx=True,
|
||||
mask_ratio=hyp.mask_ratio,
|
||||
mask_overlap=hyp.overlap_mask,
|
||||
)
|
||||
|
||||
|
||||
class Explorer:
|
||||
"""Utility class for image embedding, table creation, and similarity querying using LanceDB and YOLO models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Union[str, Path] = "coco128.yaml",
|
||||
model: str = "yolov8n.pt",
|
||||
uri: str = USER_CONFIG_DIR / "explorer",
|
||||
) -> None:
|
||||
"""Initializes the Explorer class with dataset path, model, and URI for database connection."""
|
||||
# Note duckdb==0.10.0 bug https://github.com/ultralytics/ultralytics/pull/8181
|
||||
checks.check_requirements(["lancedb>=0.4.3", "duckdb<=0.9.2"])
|
||||
import lancedb
|
||||
|
||||
self.connection = lancedb.connect(uri)
|
||||
self.table_name = f"{Path(data).name.lower()}_{model.lower()}"
|
||||
self.sim_idx_base_name = (
|
||||
f"{self.table_name}_sim_idx".lower()
|
||||
) # Use this name and append thres and top_k to reuse the table
|
||||
self.model = YOLO(model)
|
||||
self.data = data # None
|
||||
self.choice_set = None
|
||||
|
||||
self.table = None
|
||||
self.progress = 0
|
||||
|
||||
def create_embeddings_table(self, force: bool = False, split: str = "train") -> None:
|
||||
"""
|
||||
Create LanceDB table containing the embeddings of the images in the dataset. The table will be reused if it
|
||||
already exists. Pass force=True to overwrite the existing table.
|
||||
|
||||
Args:
|
||||
force (bool): Whether to overwrite the existing table or not. Defaults to False.
|
||||
split (str): Split of the dataset to use. Defaults to 'train'.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
```
|
||||
"""
|
||||
if self.table is not None and not force:
|
||||
LOGGER.info("Table already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
return
|
||||
if self.table_name in self.connection.table_names() and not force:
|
||||
LOGGER.info(f"Table {self.table_name} already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
self.table = self.connection.open_table(self.table_name)
|
||||
self.progress = 1
|
||||
return
|
||||
if self.data is None:
|
||||
raise ValueError("Data must be provided to create embeddings table")
|
||||
|
||||
data_info = check_det_dataset(self.data)
|
||||
if split not in data_info:
|
||||
raise ValueError(
|
||||
f"Split {split} is not found in the dataset. Available keys in the dataset are {list(data_info.keys())}"
|
||||
)
|
||||
|
||||
choice_set = data_info[split]
|
||||
choice_set = choice_set if isinstance(choice_set, list) else [choice_set]
|
||||
self.choice_set = choice_set
|
||||
dataset = ExplorerDataset(img_path=choice_set, data=data_info, augment=False, cache=False, task=self.model.task)
|
||||
|
||||
# Create the table schema
|
||||
batch = dataset[0]
|
||||
vector_size = self.model.embed(batch["im_file"], verbose=False)[0].shape[0]
|
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode="overwrite")
|
||||
table.add(
|
||||
self._yield_batches(
|
||||
dataset,
|
||||
data_info,
|
||||
self.model,
|
||||
exclude_keys=["img", "ratio_pad", "resized_shape", "ori_shape", "batch_idx"],
|
||||
)
|
||||
)
|
||||
|
||||
self.table = table
|
||||
|
||||
def _yield_batches(self, dataset: ExplorerDataset, data_info: dict, model: YOLO, exclude_keys: List[str]):
|
||||
"""Generates batches of data for embedding, excluding specified keys."""
|
||||
for i in tqdm(range(len(dataset))):
|
||||
self.progress = float(i + 1) / len(dataset)
|
||||
batch = dataset[i]
|
||||
for k in exclude_keys:
|
||||
batch.pop(k, None)
|
||||
batch = sanitize_batch(batch, data_info)
|
||||
batch["vector"] = model.embed(batch["im_file"], verbose=False)[0].detach().tolist()
|
||||
yield [batch]
|
||||
|
||||
def query(
|
||||
self, imgs: Union[str, np.ndarray, List[str], List[np.ndarray]] = None, limit: int = 25
|
||||
) -> Any: # pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
Args:
|
||||
imgs (str or list): Path to the image or a list of paths to the images.
|
||||
limit (int): Number of results to return.
|
||||
|
||||
Returns:
|
||||
(pyarrow.Table): An arrow table containing the results. Supports converting to:
|
||||
- pandas dataframe: `result.to_pandas()`
|
||||
- dict of lists: `result.to_pydict()`
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
similar = exp.query(img="https://ultralytics.com/images/zidane.jpg")
|
||||
```
|
||||
"""
|
||||
if self.table is None:
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
if isinstance(imgs, str):
|
||||
imgs = [imgs]
|
||||
assert isinstance(imgs, list), f"img must be a string or a list of strings. Got {type(imgs)}"
|
||||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
return self.table.search(embeds).limit(limit).to_arrow()
|
||||
|
||||
def sql_query(
|
||||
self, query: str, return_type: str = "pandas"
|
||||
) -> Union[Any, None]: # pandas.DataFrame or pyarrow.Table
|
||||
"""
|
||||
Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.
|
||||
|
||||
Args:
|
||||
query (str): SQL query to run.
|
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
||||
|
||||
Returns:
|
||||
(pyarrow.Table): An arrow table containing the results.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
||||
result = exp.sql_query(query)
|
||||
```
|
||||
"""
|
||||
assert return_type in {
|
||||
"pandas",
|
||||
"arrow",
|
||||
}, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
|
||||
import duckdb
|
||||
|
||||
if self.table is None:
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
||||
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
||||
if not query.startswith("SELECT") and not query.startswith("WHERE"):
|
||||
raise ValueError(
|
||||
f"Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE "
|
||||
f"clause. found {query}"
|
||||
)
|
||||
if query.startswith("WHERE"):
|
||||
query = f"SELECT * FROM 'table' {query}"
|
||||
LOGGER.info(f"Running query: {query}")
|
||||
|
||||
rs = duckdb.sql(query)
|
||||
if return_type == "arrow":
|
||||
return rs.arrow()
|
||||
elif return_type == "pandas":
|
||||
return rs.df()
|
||||
|
||||
def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
|
||||
"""
|
||||
Plot the results of a SQL-Like query on the table.
|
||||
|
||||
Args:
|
||||
query (str): SQL query to run.
|
||||
labels (bool): Whether to plot the labels or not.
|
||||
|
||||
Returns:
|
||||
(PIL.Image): Image containing the plot.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
query = "SELECT * FROM 'table' WHERE labels LIKE '%person%'"
|
||||
result = exp.plot_sql_query(query)
|
||||
```
|
||||
"""
|
||||
result = self.sql_query(query, return_type="arrow")
|
||||
if len(result) == 0:
|
||||
LOGGER.info("No results found.")
|
||||
return None
|
||||
img = plot_query_result(result, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(
|
||||
self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
return_type: str = "pandas",
|
||||
) -> Any: # pandas.DataFrame or pyarrow.Table
|
||||
"""
|
||||
Query the table for similar images. Accepts a single image or a list of images.
|
||||
|
||||
Args:
|
||||
img (str or list): Path to the image or a list of paths to the images.
|
||||
idx (int or list): Index of the image in the table or a list of indexes.
|
||||
limit (int): Number of results to return. Defaults to 25.
|
||||
return_type (str): Type of the result to return. Can be either 'pandas' or 'arrow'. Defaults to 'pandas'.
|
||||
|
||||
Returns:
|
||||
(pandas.DataFrame): A dataframe containing the results.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
similar = exp.get_similar(img="https://ultralytics.com/images/zidane.jpg")
|
||||
```
|
||||
"""
|
||||
assert return_type in {"pandas", "arrow"}, f"Return type should be `pandas` or `arrow`, but got {return_type}"
|
||||
img = self._check_imgs_or_idxs(img, idx)
|
||||
similar = self.query(img, limit=limit)
|
||||
|
||||
if return_type == "arrow":
|
||||
return similar
|
||||
elif return_type == "pandas":
|
||||
return similar.to_pandas()
|
||||
|
||||
def plot_similar(
|
||||
self,
|
||||
img: Union[str, np.ndarray, List[str], List[np.ndarray]] = None,
|
||||
idx: Union[int, List[int]] = None,
|
||||
limit: int = 25,
|
||||
labels: bool = True,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Plot the similar images. Accepts images or indexes.
|
||||
|
||||
Args:
|
||||
img (str or list): Path to the image or a list of paths to the images.
|
||||
idx (int or list): Index of the image in the table or a list of indexes.
|
||||
labels (bool): Whether to plot the labels or not.
|
||||
limit (int): Number of results to return. Defaults to 25.
|
||||
|
||||
Returns:
|
||||
(PIL.Image): Image containing the plot.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
similar = exp.plot_similar(img="https://ultralytics.com/images/zidane.jpg")
|
||||
```
|
||||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type="arrow")
|
||||
if len(similar) == 0:
|
||||
LOGGER.info("No results found.")
|
||||
return None
|
||||
img = plot_query_result(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any: # pd.DataFrame
|
||||
"""
|
||||
Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
|
||||
are max_dist or closer to the image in the embedding space at a given index.
|
||||
|
||||
Args:
|
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
||||
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit.
|
||||
vector search. Defaults: None.
|
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(pandas.DataFrame): A dataframe containing the similarity index. Each row corresponds to an image,
|
||||
and columns include indices of similar images and their respective distances.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
sim_idx = exp.similarity_index()
|
||||
```
|
||||
"""
|
||||
if self.table is None:
|
||||
raise ValueError("Table is not created. Please create the table first.")
|
||||
sim_idx_table_name = f"{self.sim_idx_base_name}_thres_{max_dist}_top_{top_k}".lower()
|
||||
if sim_idx_table_name in self.connection.table_names() and not force:
|
||||
LOGGER.info("Similarity matrix already exists. Reusing it. Pass force=True to overwrite it.")
|
||||
return self.connection.open_table(sim_idx_table_name).to_pandas()
|
||||
|
||||
if top_k and not (1.0 >= top_k >= 0.0):
|
||||
raise ValueError(f"top_k must be between 0.0 and 1.0. Got {top_k}")
|
||||
if max_dist < 0.0:
|
||||
raise ValueError(f"max_dist must be greater than 0. Got {max_dist}")
|
||||
|
||||
top_k = int(top_k * len(self.table)) if top_k else len(self.table)
|
||||
top_k = max(top_k, 1)
|
||||
features = self.table.to_lance().to_table(columns=["vector", "im_file"]).to_pydict()
|
||||
im_files = features["im_file"]
|
||||
embeddings = features["vector"]
|
||||
|
||||
sim_table = self.connection.create_table(sim_idx_table_name, schema=get_sim_index_schema(), mode="overwrite")
|
||||
|
||||
def _yield_sim_idx():
|
||||
"""Generates a dataframe with similarity indices and distances for images."""
|
||||
for i in tqdm(range(len(embeddings))):
|
||||
sim_idx = self.table.search(embeddings[i]).limit(top_k).to_pandas().query(f"_distance <= {max_dist}")
|
||||
yield [
|
||||
{
|
||||
"idx": i,
|
||||
"im_file": im_files[i],
|
||||
"count": len(sim_idx),
|
||||
"sim_im_files": sim_idx["im_file"].tolist(),
|
||||
}
|
||||
]
|
||||
|
||||
sim_table.add(_yield_sim_idx())
|
||||
self.sim_index = sim_table
|
||||
return sim_table.to_pandas()
|
||||
|
||||
def plot_similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Image:
|
||||
"""
|
||||
Plot the similarity index of all the images in the table. Here, the index will contain the data points that are
|
||||
max_dist or closer to the image in the embedding space at a given index.
|
||||
|
||||
Args:
|
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
||||
top_k (float): Percentage of closest data points to consider when counting. Used to apply limit when
|
||||
running vector search. Defaults to 0.01.
|
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to True.
|
||||
|
||||
Returns:
|
||||
(PIL.Image): Image containing the plot.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
|
||||
similarity_idx_plot = exp.plot_similarity_index()
|
||||
similarity_idx_plot.show() # view image preview
|
||||
similarity_idx_plot.save("path/to/save/similarity_index_plot.png") # save contents to file
|
||||
```
|
||||
"""
|
||||
sim_idx = self.similarity_index(max_dist=max_dist, top_k=top_k, force=force)
|
||||
sim_count = sim_idx["count"].tolist()
|
||||
sim_count = np.array(sim_count)
|
||||
|
||||
indices = np.arange(len(sim_count))
|
||||
|
||||
# Create the bar plot
|
||||
plt.bar(indices, sim_count)
|
||||
|
||||
# Customize the plot (optional)
|
||||
plt.xlabel("data idx")
|
||||
plt.ylabel("Count")
|
||||
plt.title("Similarity Count")
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format="png")
|
||||
buffer.seek(0)
|
||||
|
||||
# Use Pillow to open the image from the buffer
|
||||
return Image.fromarray(np.array(Image.open(buffer)))
|
||||
|
||||
def _check_imgs_or_idxs(
|
||||
self, img: Union[str, np.ndarray, List[str], List[np.ndarray], None], idx: Union[None, int, List[int]]
|
||||
) -> List[np.ndarray]:
|
||||
"""Determines whether to fetch images or indexes based on provided arguments and returns image paths."""
|
||||
if img is None and idx is None:
|
||||
raise ValueError("Either img or idx must be provided.")
|
||||
if img is not None and idx is not None:
|
||||
raise ValueError("Only one of img or idx must be provided.")
|
||||
if idx is not None:
|
||||
idx = idx if isinstance(idx, list) else [idx]
|
||||
img = self.table.to_lance().take(idx, columns=["im_file"]).to_pydict()["im_file"]
|
||||
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
def ask_ai(self, query):
|
||||
"""
|
||||
Ask AI a question.
|
||||
|
||||
Args:
|
||||
query (str): Question to ask.
|
||||
|
||||
Returns:
|
||||
(pandas.DataFrame): A dataframe containing filtered results to the SQL query.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
answer = exp.ask_ai("Show images with 1 person and 2 dogs")
|
||||
```
|
||||
"""
|
||||
result = prompt_sql_query(query)
|
||||
try:
|
||||
return self.sql_query(result)
|
||||
except Exception as e:
|
||||
LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
|
||||
LOGGER.error(e)
|
||||
return None
|
||||
|
|
@ -1 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
|
@ -1,269 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from ultralytics import Explorer
|
||||
from ultralytics.utils import ROOT, SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements(("streamlit>=1.29.0", "streamlit-select>=0.3"))
|
||||
|
||||
import streamlit as st
|
||||
from streamlit_select import image_select
|
||||
|
||||
|
||||
def _get_explorer():
|
||||
"""Initializes and returns an instance of the Explorer class."""
|
||||
exp = Explorer(data=st.session_state.get("dataset"), model=st.session_state.get("model"))
|
||||
thread = Thread(
|
||||
target=exp.create_embeddings_table,
|
||||
kwargs={"force": st.session_state.get("force_recreate_embeddings"), "split": st.session_state.get("split")},
|
||||
)
|
||||
thread.start()
|
||||
progress_bar = st.progress(0, text="Creating embeddings table...")
|
||||
while exp.progress < 1:
|
||||
time.sleep(0.1)
|
||||
progress_bar.progress(exp.progress, text=f"Progress: {exp.progress * 100}%")
|
||||
thread.join()
|
||||
st.session_state["explorer"] = exp
|
||||
progress_bar.empty()
|
||||
|
||||
|
||||
def init_explorer_form(data=None, model=None):
|
||||
"""Initializes an Explorer instance and creates embeddings table with progress tracking."""
|
||||
if data is None:
|
||||
datasets = ROOT / "cfg" / "datasets"
|
||||
ds = [d.name for d in datasets.glob("*.yaml")]
|
||||
else:
|
||||
ds = [data]
|
||||
|
||||
prefixes = ["yolov8", "yolo11"]
|
||||
sizes = ["n", "s", "m", "l", "x"]
|
||||
tasks = ["", "-seg", "-pose"]
|
||||
if model is None:
|
||||
models = [f"{p}{s}{t}" for p in prefixes for s in sizes for t in tasks]
|
||||
else:
|
||||
models = [model]
|
||||
|
||||
splits = ["train", "val", "test"]
|
||||
|
||||
with st.form(key="explorer_init_form"):
|
||||
col1, col2, col3 = st.columns(3)
|
||||
with col1:
|
||||
st.selectbox("Select dataset", ds, key="dataset")
|
||||
with col2:
|
||||
st.selectbox("Select model", models, key="model")
|
||||
with col3:
|
||||
st.selectbox("Select split", splits, key="split")
|
||||
st.checkbox("Force recreate embeddings", key="force_recreate_embeddings")
|
||||
|
||||
st.form_submit_button("Explore", on_click=_get_explorer)
|
||||
|
||||
|
||||
def query_form():
|
||||
"""Sets up a form in Streamlit to initialize Explorer with dataset and model selection."""
|
||||
with st.form("query_form"):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
st.text_input(
|
||||
"Query",
|
||||
"WHERE labels LIKE '%person%' AND labels LIKE '%dog%'",
|
||||
label_visibility="collapsed",
|
||||
key="query",
|
||||
)
|
||||
with col2:
|
||||
st.form_submit_button("Query", on_click=run_sql_query)
|
||||
|
||||
|
||||
def ai_query_form():
|
||||
"""Sets up a Streamlit form for user input to initialize Explorer with dataset and model selection."""
|
||||
with st.form("ai_query_form"):
|
||||
col1, col2 = st.columns([0.8, 0.2])
|
||||
with col1:
|
||||
st.text_input("Query", "Show images with 1 person and 1 dog", label_visibility="collapsed", key="ai_query")
|
||||
with col2:
|
||||
st.form_submit_button("Ask AI", on_click=run_ai_query)
|
||||
|
||||
|
||||
def find_similar_imgs(imgs):
|
||||
"""Initializes a Streamlit form for AI-based image querying with custom input."""
|
||||
exp = st.session_state["explorer"]
|
||||
similar = exp.get_similar(img=imgs, limit=st.session_state.get("limit"), return_type="arrow")
|
||||
paths = similar.to_pydict()["im_file"]
|
||||
st.session_state["imgs"] = paths
|
||||
st.session_state["res"] = similar
|
||||
|
||||
|
||||
def similarity_form(selected_imgs):
|
||||
"""Initializes a form for AI-based image querying with custom input in Streamlit."""
|
||||
st.write("Similarity Search")
|
||||
with st.form("similarity_form"):
|
||||
subcol1, subcol2 = st.columns([1, 1])
|
||||
with subcol1:
|
||||
st.number_input(
|
||||
"limit", min_value=None, max_value=None, value=25, label_visibility="collapsed", key="limit"
|
||||
)
|
||||
|
||||
with subcol2:
|
||||
disabled = not len(selected_imgs)
|
||||
st.write("Selected: ", len(selected_imgs))
|
||||
st.form_submit_button(
|
||||
"Search",
|
||||
disabled=disabled,
|
||||
on_click=find_similar_imgs,
|
||||
args=(selected_imgs,),
|
||||
)
|
||||
if disabled:
|
||||
st.error("Select at least one image to search.")
|
||||
|
||||
|
||||
# def persist_reset_form():
|
||||
# with st.form("persist_reset"):
|
||||
# col1, col2 = st.columns([1, 1])
|
||||
# with col1:
|
||||
# st.form_submit_button("Reset", on_click=reset)
|
||||
#
|
||||
# with col2:
|
||||
# st.form_submit_button("Persist", on_click=update_state, args=("PERSISTING", True))
|
||||
|
||||
|
||||
def run_sql_query():
|
||||
"""Executes an SQL query and returns the results."""
|
||||
st.session_state["error"] = None
|
||||
query = st.session_state.get("query")
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state["explorer"]
|
||||
res = exp.sql_query(query, return_type="arrow")
|
||||
st.session_state["imgs"] = res.to_pydict()["im_file"]
|
||||
st.session_state["res"] = res
|
||||
|
||||
|
||||
def run_ai_query():
|
||||
"""Execute SQL query and update session state with query results."""
|
||||
if not SETTINGS["openai_api_key"]:
|
||||
st.session_state["error"] = (
|
||||
'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
|
||||
)
|
||||
return
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
st.session_state["error"] = None
|
||||
query = st.session_state.get("ai_query")
|
||||
if query.rstrip().lstrip():
|
||||
exp = st.session_state["explorer"]
|
||||
res = exp.ask_ai(query)
|
||||
if not isinstance(res, pandas.DataFrame) or res.empty:
|
||||
st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
|
||||
return
|
||||
st.session_state["imgs"] = res["im_file"].to_list()
|
||||
st.session_state["res"] = res
|
||||
|
||||
|
||||
def reset_explorer():
|
||||
"""Resets the explorer to its initial state by clearing session variables."""
|
||||
st.session_state["explorer"] = None
|
||||
st.session_state["imgs"] = None
|
||||
st.session_state["error"] = None
|
||||
|
||||
|
||||
def utralytics_explorer_docs_callback():
|
||||
"""Resets the explorer to its initial state by clearing session variables."""
|
||||
with st.container(border=True):
|
||||
st.image(
|
||||
"https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg",
|
||||
width=100,
|
||||
)
|
||||
st.markdown(
|
||||
"<p>This demo is built using Ultralytics Explorer API. Visit <a href='https://docs.ultralytics.com/datasets/explorer/'>API docs</a> to try examples & learn more</p>",
|
||||
unsafe_allow_html=True,
|
||||
help=None,
|
||||
)
|
||||
st.link_button("Ultrlaytics Explorer API", "https://docs.ultralytics.com/datasets/explorer/")
|
||||
|
||||
|
||||
def layout(data=None, model=None):
|
||||
"""Resets explorer session variables and provides documentation with a link to API docs."""
|
||||
st.set_page_config(layout="wide", initial_sidebar_state="collapsed")
|
||||
st.markdown("<h1 style='text-align: center;'>Ultralytics Explorer Demo</h1>", unsafe_allow_html=True)
|
||||
|
||||
if st.session_state.get("explorer") is None:
|
||||
init_explorer_form(data, model)
|
||||
return
|
||||
|
||||
st.button(":arrow_backward: Select Dataset", on_click=reset_explorer)
|
||||
exp = st.session_state.get("explorer")
|
||||
col1, col2 = st.columns([0.75, 0.25], gap="small")
|
||||
imgs = []
|
||||
if st.session_state.get("error"):
|
||||
st.error(st.session_state["error"])
|
||||
elif st.session_state.get("imgs"):
|
||||
imgs = st.session_state.get("imgs")
|
||||
else:
|
||||
imgs = exp.table.to_lance().to_table(columns=["im_file"]).to_pydict()["im_file"]
|
||||
st.session_state["res"] = exp.table.to_arrow()
|
||||
total_imgs, selected_imgs = len(imgs), []
|
||||
with col1:
|
||||
subcol1, subcol2, subcol3, subcol4, subcol5 = st.columns(5)
|
||||
with subcol1:
|
||||
st.write("Max Images Displayed:")
|
||||
with subcol2:
|
||||
num = st.number_input(
|
||||
"Max Images Displayed",
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=min(500, total_imgs),
|
||||
key="num_imgs_displayed",
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
with subcol3:
|
||||
st.write("Start Index:")
|
||||
with subcol4:
|
||||
start_idx = st.number_input(
|
||||
"Start Index",
|
||||
min_value=0,
|
||||
max_value=total_imgs,
|
||||
value=0,
|
||||
key="start_index",
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
with subcol5:
|
||||
reset = st.button("Reset", use_container_width=False, key="reset")
|
||||
if reset:
|
||||
st.session_state["imgs"] = None
|
||||
st.experimental_rerun()
|
||||
|
||||
query_form()
|
||||
ai_query_form()
|
||||
if total_imgs:
|
||||
labels, boxes, masks, kpts, classes = None, None, None, None, None
|
||||
task = exp.model.task
|
||||
if st.session_state.get("display_labels"):
|
||||
labels = st.session_state.get("res").to_pydict()["labels"][start_idx : start_idx + num]
|
||||
boxes = st.session_state.get("res").to_pydict()["bboxes"][start_idx : start_idx + num]
|
||||
masks = st.session_state.get("res").to_pydict()["masks"][start_idx : start_idx + num]
|
||||
kpts = st.session_state.get("res").to_pydict()["keypoints"][start_idx : start_idx + num]
|
||||
classes = st.session_state.get("res").to_pydict()["cls"][start_idx : start_idx + num]
|
||||
imgs_displayed = imgs[start_idx : start_idx + num]
|
||||
selected_imgs = image_select(
|
||||
f"Total samples: {total_imgs}",
|
||||
images=imgs_displayed,
|
||||
use_container_width=False,
|
||||
# indices=[i for i in range(num)] if select_all else None,
|
||||
labels=labels,
|
||||
classes=classes,
|
||||
bboxes=boxes,
|
||||
masks=masks if task == "segment" else None,
|
||||
kpts=kpts if task == "pose" else None,
|
||||
)
|
||||
|
||||
with col2:
|
||||
similarity_form(selected_imgs)
|
||||
st.checkbox("Labels", value=False, key="display_labels")
|
||||
utralytics_explorer_docs_callback()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
kwargs = dict(zip(sys.argv[1::2], sys.argv[2::2]))
|
||||
layout(**kwargs)
|
||||
|
|
@ -1,167 +0,0 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import getpass
|
||||
from typing import List
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.utils import LOGGER as logger
|
||||
from ultralytics.utils import SETTINGS
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.ops import xyxy2xywh
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
def get_table_schema(vector_size):
|
||||
"""Extracts and returns the schema of a database table."""
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
class Schema(LanceModel):
|
||||
im_file: str
|
||||
labels: List[str]
|
||||
cls: List[int]
|
||||
bboxes: List[List[float]]
|
||||
masks: List[List[List[int]]]
|
||||
keypoints: List[List[List[float]]]
|
||||
vector: Vector(vector_size)
|
||||
|
||||
return Schema
|
||||
|
||||
|
||||
def get_sim_index_schema():
|
||||
"""Returns a LanceModel schema for a database table with specified vector size."""
|
||||
from lancedb.pydantic import LanceModel
|
||||
|
||||
class Schema(LanceModel):
|
||||
idx: int
|
||||
im_file: str
|
||||
count: int
|
||||
sim_im_files: List[str]
|
||||
|
||||
return Schema
|
||||
|
||||
|
||||
def sanitize_batch(batch, dataset_info):
|
||||
"""Sanitizes input batch for inference, ensuring correct format and dimensions."""
|
||||
batch["cls"] = batch["cls"].flatten().int().tolist()
|
||||
box_cls_pair = sorted(zip(batch["bboxes"].tolist(), batch["cls"]), key=lambda x: x[1])
|
||||
batch["bboxes"] = [box for box, _ in box_cls_pair]
|
||||
batch["cls"] = [cls for _, cls in box_cls_pair]
|
||||
batch["labels"] = [dataset_info["names"][i] for i in batch["cls"]]
|
||||
batch["masks"] = batch["masks"].tolist() if "masks" in batch else [[[]]]
|
||||
batch["keypoints"] = batch["keypoints"].tolist() if "keypoints" in batch else [[[]]]
|
||||
return batch
|
||||
|
||||
|
||||
def plot_query_result(similar_set, plot_labels=True):
|
||||
"""
|
||||
Plot images from the similar set.
|
||||
|
||||
Args:
|
||||
similar_set (list): Pyarrow or pandas object containing the similar data points
|
||||
plot_labels (bool): Whether to plot labels or not
|
||||
"""
|
||||
import pandas # scope for faster 'import ultralytics'
|
||||
|
||||
similar_set = (
|
||||
similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict()
|
||||
)
|
||||
empty_masks = [[[]]]
|
||||
empty_boxes = [[]]
|
||||
images = similar_set.get("im_file", [])
|
||||
bboxes = similar_set.get("bboxes", []) if similar_set.get("bboxes") is not empty_boxes else []
|
||||
masks = similar_set.get("masks") if similar_set.get("masks")[0] != empty_masks else []
|
||||
kpts = similar_set.get("keypoints") if similar_set.get("keypoints")[0] != empty_masks else []
|
||||
cls = similar_set.get("cls", [])
|
||||
|
||||
plot_size = 640
|
||||
imgs, batch_idx, plot_boxes, plot_masks, plot_kpts = [], [], [], [], []
|
||||
for i, imf in enumerate(images):
|
||||
im = cv2.imread(imf)
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
h, w = im.shape[:2]
|
||||
r = min(plot_size / h, plot_size / w)
|
||||
imgs.append(LetterBox(plot_size, center=False)(image=im).transpose(2, 0, 1))
|
||||
if plot_labels:
|
||||
if len(bboxes) > i and len(bboxes[i]) > 0:
|
||||
box = np.array(bboxes[i], dtype=np.float32)
|
||||
box[:, [0, 2]] *= r
|
||||
box[:, [1, 3]] *= r
|
||||
plot_boxes.append(box)
|
||||
if len(masks) > i and len(masks[i]) > 0:
|
||||
mask = np.array(masks[i], dtype=np.uint8)[0]
|
||||
plot_masks.append(LetterBox(plot_size, center=False)(image=mask))
|
||||
if len(kpts) > i and kpts[i] is not None:
|
||||
kpt = np.array(kpts[i], dtype=np.float32)
|
||||
kpt[:, :, :2] *= r
|
||||
plot_kpts.append(kpt)
|
||||
batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
|
||||
imgs = np.stack(imgs, axis=0)
|
||||
masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
|
||||
kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
|
||||
boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
|
||||
batch_idx = np.concatenate(batch_idx, axis=0)
|
||||
cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
|
||||
|
||||
return plot_images(
|
||||
imgs, batch_idx, cls, bboxes=boxes, masks=masks, kpts=kpts, max_subplots=len(images), save=False, threaded=False
|
||||
)
|
||||
|
||||
|
||||
def prompt_sql_query(query):
|
||||
"""Plots images with optional labels from a similar data set."""
|
||||
check_requirements("openai>=1.6.1")
|
||||
from openai import OpenAI
|
||||
|
||||
if not SETTINGS["openai_api_key"]:
|
||||
logger.warning("OpenAI API key not found in settings. Please enter your API key below.")
|
||||
openai_api_key = getpass.getpass("OpenAI API key: ")
|
||||
SETTINGS.update({"openai_api_key": openai_api_key})
|
||||
openai = OpenAI(api_key=SETTINGS["openai_api_key"])
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
You are a helpful data scientist proficient in SQL. You need to output exactly one SQL query based on
|
||||
the following schema and a user request. You only need to output the format with fixed selection
|
||||
statement that selects everything from "'table'", like `SELECT * from 'table'`
|
||||
|
||||
Schema:
|
||||
im_file: string not null
|
||||
labels: list<item: string> not null
|
||||
child 0, item: string
|
||||
cls: list<item: int64> not null
|
||||
child 0, item: int64
|
||||
bboxes: list<item: list<item: double>> not null
|
||||
child 0, item: list<item: double>
|
||||
child 0, item: double
|
||||
masks: list<item: list<item: list<item: int64>>> not null
|
||||
child 0, item: list<item: list<item: int64>>
|
||||
child 0, item: list<item: int64>
|
||||
child 0, item: int64
|
||||
keypoints: list<item: list<item: list<item: double>>> not null
|
||||
child 0, item: list<item: list<item: double>>
|
||||
child 0, item: list<item: double>
|
||||
child 0, item: double
|
||||
vector: fixed_size_list<item: float>[256] not null
|
||||
child 0, item: float
|
||||
|
||||
Some details about the schema:
|
||||
- the "labels" column contains the string values like 'person' and 'dog' for the respective objects
|
||||
in each image
|
||||
- the "cls" column contains the integer values on these classes that map them the labels
|
||||
|
||||
Example of a correct query:
|
||||
request - Get all data points that contain 2 or more people and at least one dog
|
||||
correct query-
|
||||
SELECT * FROM 'table' WHERE ARRAY_LENGTH(cls) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'person')) >= 2 AND ARRAY_LENGTH(FILTER(labels, x -> x = 'dog')) >= 1;
|
||||
""",
|
||||
},
|
||||
{"role": "user", "content": f"{query}"},
|
||||
]
|
||||
|
||||
response = openai.chat.completions.create(model="gpt-3.5-turbo", messages=messages)
|
||||
return response.choices[0].message.content
|
||||
Loading…
Add table
Add a link
Reference in a new issue