ultralytics 8.0.238 Explorer Ask AI feature and fixes (#7408)
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: uwer <uwe.rosebrock@gmail.com> Co-authored-by: Uwe Rosebrock <ro260@csiro.au> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1182102784@qq.com> Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com> Co-authored-by: AdamP <adamp87hun@gmail.com>
This commit is contained in:
parent
e76754eab0
commit
783033fa6b
19 changed files with 387 additions and 76 deletions
|
|
@ -16,7 +16,7 @@ from ultralytics.data.utils import check_det_dataset
|
|||
from ultralytics.models.yolo.model import YOLO
|
||||
from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
|
||||
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_similar_images, sanitize_batch
|
||||
from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
|
||||
|
||||
|
||||
class ExplorerDataset(YOLODataset):
|
||||
|
|
@ -58,7 +58,7 @@ class Explorer:
|
|||
data: Union[str, Path] = 'coco128.yaml',
|
||||
model: str = 'yolov8n.pt',
|
||||
uri: str = '~/ultralytics/explorer') -> None:
|
||||
checks.check_requirements(['lancedb', 'duckdb'])
|
||||
checks.check_requirements(['lancedb>=0.4.3', 'duckdb'])
|
||||
import lancedb
|
||||
|
||||
self.connection = lancedb.connect(uri)
|
||||
|
|
@ -112,8 +112,7 @@ class Explorer:
|
|||
# Create the table schema
|
||||
batch = dataset[0]
|
||||
vector_size = self.model.embed(batch['im_file'], verbose=False)[0].shape[0]
|
||||
Schema = get_table_schema(vector_size)
|
||||
table = self.connection.create_table(self.table_name, schema=Schema, mode='overwrite')
|
||||
table = self.connection.create_table(self.table_name, schema=get_table_schema(vector_size), mode='overwrite')
|
||||
table.add(
|
||||
self._yield_batches(dataset,
|
||||
data_info,
|
||||
|
|
@ -159,10 +158,7 @@ class Explorer:
|
|||
raise ValueError('Table is not created. Please create the table first.')
|
||||
if isinstance(imgs, str):
|
||||
imgs = [imgs]
|
||||
elif isinstance(imgs, list):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f'img must be a string or a list of strings. Got {type(imgs)}')
|
||||
assert isinstance(imgs, list), f'img must be a string or a list of strings. Got {type(imgs)}'
|
||||
embeds = self.model.embed(imgs)
|
||||
# Get avg if multiple images are passed (len > 1)
|
||||
embeds = torch.mean(torch.stack(embeds), 0).cpu().numpy() if len(embeds) > 1 else embeds[0].cpu().numpy()
|
||||
|
|
@ -189,16 +185,19 @@ class Explorer:
|
|||
result = exp.sql_query(query)
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
import duckdb
|
||||
|
||||
if self.table is None:
|
||||
raise ValueError('Table is not created. Please create the table first.')
|
||||
|
||||
# Note: using filter pushdown would be a better long term solution. Temporarily using duckdb for this.
|
||||
table = self.table.to_arrow() # noqa
|
||||
table = self.table.to_arrow() # noqa NOTE: Don't comment this. This line is used by DuckDB
|
||||
if not query.startswith('SELECT') and not query.startswith('WHERE'):
|
||||
raise ValueError(
|
||||
'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause.')
|
||||
f'Query must start with SELECT or WHERE. You can either pass the entire query or just the WHERE clause. found {query}'
|
||||
)
|
||||
if query.startswith('WHERE'):
|
||||
query = f"SELECT * FROM 'table' {query}"
|
||||
LOGGER.info(f'Running query: {query}')
|
||||
|
|
@ -228,7 +227,10 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
result = self.sql_query(query, return_type='arrow')
|
||||
img = plot_similar_images(result, plot_labels=labels)
|
||||
if len(result) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
return None
|
||||
img = plot_query_result(result, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def get_similar(self,
|
||||
|
|
@ -255,6 +257,8 @@ class Explorer:
|
|||
similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
|
||||
```
|
||||
"""
|
||||
assert return_type in ['pandas',
|
||||
'arrow'], f'Return type should be either `pandas` or `arrow`, but got {return_type}'
|
||||
img = self._check_imgs_or_idxs(img, idx)
|
||||
similar = self.query(img, limit=limit)
|
||||
|
||||
|
|
@ -288,7 +292,10 @@ class Explorer:
|
|||
```
|
||||
"""
|
||||
similar = self.get_similar(img, idx, limit, return_type='arrow')
|
||||
img = plot_similar_images(similar, plot_labels=labels)
|
||||
if len(similar) == 0:
|
||||
LOGGER.info('No results found.')
|
||||
return None
|
||||
img = plot_query_result(similar, plot_labels=labels)
|
||||
return Image.fromarray(img)
|
||||
|
||||
def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
|
||||
|
|
@ -299,7 +306,7 @@ class Explorer:
|
|||
Args:
|
||||
max_dist (float): maximum L2 distance between the embeddings to consider. Defaults to 0.2.
|
||||
top_k (float): Percentage of the closest data points to consider when counting. Used to apply limit when running
|
||||
vector search. Defaults to 0.01.
|
||||
vector search. Defaults: None.
|
||||
force (bool): Whether to overwrite the existing similarity index or not. Defaults to False.
|
||||
|
||||
Returns:
|
||||
|
|
@ -401,6 +408,32 @@ class Explorer:
|
|||
|
||||
return img if isinstance(img, list) else [img]
|
||||
|
||||
def ask_ai(self, query):
|
||||
"""
|
||||
Ask AI a question.
|
||||
|
||||
Args:
|
||||
query (str): Question to ask.
|
||||
|
||||
Returns:
|
||||
Result of running the AI-generated SQL query, or None if the generated query is not valid.
|
||||
|
||||
Example:
|
||||
```python
|
||||
exp = Explorer()
|
||||
exp.create_embeddings_table()
|
||||
answer = exp.ask_ai('Show images with 1 person and 2 dogs')
|
||||
```
|
||||
"""
|
||||
result = prompt_sql_query(query)
|
||||
try:
|
||||
df = self.sql_query(result)
|
||||
except Exception as e:
|
||||
LOGGER.error('AI generated query is not valid. Please try again with a different prompt')
|
||||
LOGGER.error(e)
|
||||
return None
|
||||
return df
|
||||
|
||||
def visualize(self, result):
|
||||
"""
|
||||
Visualize the results of a query.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue