ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)

Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: uwer <uwe.rosebrock@gmail.com>
Co-authored-by: Uwe Rosebrock <ro260@csiro.au>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Laughing-q <1182102784@qq.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: AdamP <adamp87hun@gmail.com>
This commit is contained in:
Glenn Jocher 2024-01-08 23:36:29 +01:00 committed by GitHub
parent e76754eab0
commit 783033fa6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 387 additions and 76 deletions

View file

@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.model import YOLO
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
class ExplorerDataset(YOLODataset):
@ -58,7 +58,7 @@ class Explorer:
data: Union[str, Path] = 'coco128.yaml',
model: str = 'yolov8n.pt',
uri: str = '~/ultralytics/explorer') -> None:
checks.check_requirements(['lancedb', 'duckdb'])
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
import lancedb
self.connection = lancedb.connect(uri)
@ -112,8 +112,7 @@ class Explorer:
# Create the table schema
batch = dataset[0]
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
Schema = get_table_schema(vector_size)
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
table.add(
self._yield_batches(dataset,
data_info,
@ -159,10 +158,7 @@ class Explorer:
raise ValueError('Table is not created. Please create the table first.')
if isinstance(imgs, str):
imgs = [imgs]
elif isinstance(imgs, list):
pass
else:
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
embeds = self.model.embed(imgs)
# Get avg if multiple images are passed (len > 1)
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
@ -189,16 +185,19 @@ class Explorer:
result = exp.sql_query(query)
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
import duckdb
if self.table is None:
raise ValueError('Table is not created. Please create the table first.')
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
table = self.table.to_arrow() # noqa
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
if not query.startswith('SELECT') and not query.startswith('WHERE'):
raise ValueError(
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
)
if query.startswith('WHERE'):
query = f"SELECT * FROM 'table' {query}"
LOGGER.info(f'Running query: {query}')
@ -228,7 +227,10 @@ class Explorer:
```
"""
result = self.sql_query(query, return_type='arrow')
img = plot_similar_images(result, plot_labels=labels)
if len(result) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(result, plot_labels=labels)
return Image.fromarray(img)
def get_similar(self,
@ -255,6 +257,8 @@ class Explorer:
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
```
"""
assert return_type in ['pandas',
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
img = self._check_imgs_or_idxs(img, idx)
similar = self.query(img, limit=limit)
@ -288,7 +292,10 @@ class Explorer:
```
"""
similar = self.get_similar(img, idx, limit, return_type='arrow')
img = plot_similar_images(similar, plot_labels=labels)
if len(similar) == 0:
LOGGER.info('No results found.')
return None
img = plot_query_result(similar, plot_labels=labels)
return Image.fromarray(img)
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
@ -299,7 +306,7 @@ class Explorer:
Args:
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
vector search. Defaults to 0.01.
vector search. Defaults: None.
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
Returns:
@ -401,6 +408,32 @@ class Explorer:
return img if isinstance(img, list) else [img]
def ask_ai(self, query):
    """
    Answer a natural-language question about the dataset.

    The question is handed to `prompt_sql_query`, and whatever it returns is executed through
    `self.sql_query`. Any exception raised while running that query is logged instead of being
    propagated, so callers must be prepared for a `None` result.

    Args:
        query (str): Question to ask.

    Returns:
        The value returned by `self.sql_query` for the generated query, or None if running the
        query raised an exception.

    Example:
        ```python
        exp = Explorer()
        exp.create_embeddings_table()
        answer = exp.ask_ai('Show images with 1 person and 2 dogs')
        ```
    """
    generated_query = prompt_sql_query(query)
    try:
        return self.sql_query(generated_query)
    except Exception as err:
        # Best-effort boundary: the generated query may be malformed, so report it and fall through.
        LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
        LOGGER.error(err)
    return None
def visualize(self, result):
"""
Visualize the results of a query.