ultralytics 8.1.43 40% faster ultralytics imports (#9547)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent 99c61d6f7b
commit a2628657a1
21 changed files with 240 additions and 225 deletions
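The entire speedup comes from one pattern, applied file by file below: heavy third-party imports (torchvision, pandas) move out of module scope and into the functions and constructors that actually use them, so 'import ultralytics' no longer pays their startup cost. A minimal sketch of the pattern, assuming torchvision is installed (the function body here is illustrative, not copied from the diff):

import time

def classify_transforms(size=224):
    """Build eval transforms; torchvision loads on first call, not at package import."""
    import torchvision.transforms as T  # scope for faster 'import ultralytics'

    return T.Compose([T.Resize(size), T.CenterCrop(size), T.ToTensor()])

t0 = time.perf_counter()
transforms = classify_transforms()  # first call pays the deferred torchvision import
print(f"first call: {time.perf_counter() - t0:.2f}s\n{transforms}")

Per-module import cost can be measured before and after a change like this with python -X importtime -c "import ultralytics".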
ultralytics/data/augment.py
@@ -8,7 +8,7 @@ from typing import Tuple, Union
 import cv2
 import numpy as np
 import torch
-import torchvision.transforms as T
+from PIL import Image

 from ultralytics.utils import LOGGER, colorstr
 from ultralytics.utils.checks import check_version
@@ -167,8 +167,8 @@ class BaseMixTransform:
         text2id = {text: i for i, text in enumerate(mix_texts)}

         for label in [labels] + labels["mix_labels"]:
-            for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
-                text = label["texts"][int(l)]
+            for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
+                text = label["texts"][int(cls)]
                 label["cls"][i] = text2id[tuple(text)]
            label["texts"] = mix_texts
         return labels
@@ -1133,7 +1133,7 @@ def classify_transforms(
     size=224,
     mean=DEFAULT_MEAN,
     std=DEFAULT_STD,
-    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+    interpolation=Image.BILINEAR,
     crop_fraction: float = DEFAULT_CROP_FTACTION,
 ):
     """
@@ -1149,6 +1149,7 @@ def classify_transforms(
     Returns:
         (T.Compose): torchvision transforms
     """
+    import torchvision.transforms as T  # scope for faster 'import ultralytics'

     if isinstance(size, (tuple, list)):
         assert len(size) == 2
@@ -1157,12 +1158,12 @@ def classify_transforms(
         scale_size = math.floor(size / crop_fraction)
         scale_size = (scale_size, scale_size)

-    # aspect ratio is preserved, crops center within image, no borders are added, image is lost
+    # Aspect ratio is preserved, crops center within image, no borders are added, image is lost
     if scale_size[0] == scale_size[1]:
-        # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+        # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
         tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
     else:
-        # resize shortest edge to matching target dim for non-square target
+        # Resize the shortest edge to matching target dim for non-square target
         tfl = [T.Resize(scale_size)]
     tfl += [T.CenterCrop(size)]

@@ -1192,7 +1193,7 @@ def classify_augmentations(
     hsv_v=0.4,  # image HSV-Value augmentation (fraction)
     force_color_jitter=False,
     erasing=0.0,
-    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+    interpolation=Image.BILINEAR,
 ):
     """
     Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
@@ -1216,7 +1217,9 @@ def classify_augmentations(
     Returns:
         (T.Compose): torchvision transforms
     """
-    # Transforms to apply if albumentations not installed
+    # Transforms to apply if Albumentations not installed
+    import torchvision.transforms as T  # scope for faster 'import ultralytics'
+
     if not isinstance(size, int):
         raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
     scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
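One reason the interpolation defaults above change from T.InterpolationMode.BILINEAR to Image.BILINEAR: default argument values are evaluated once, at function definition time, so a torchvision-typed default would force the torchvision import back to module scope. The PIL constant keeps the signature cheap. A small sketch of the distinction (function name hypothetical):

from PIL import Image

def resize_mode(interpolation=Image.BILINEAR):  # default evaluated at def time, needs only PIL
    return interpolation

print(resize_mode())  # PIL's BILINEAR resampling constant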
ultralytics/data/dataset.py
@@ -1,18 +1,17 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 import contextlib
-from itertools import repeat
+import json
 from collections import defaultdict
+from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path

 import cv2
-import json
 import numpy as np
 import torch
-import torchvision
 from PIL import Image
-
 from torch.utils.data import ConcatDataset

 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr
 from ultralytics.utils.ops import resample_segments
 from .augment import (
@@ -103,16 +102,16 @@ class YOLODataset(BaseDataset):
             nc += nc_f
             if im_file:
                 x["labels"].append(
-                    dict(
-                        im_file=im_file,
-                        shape=shape,
-                        cls=lb[:, 0:1],  # n, 1
-                        bboxes=lb[:, 1:],  # n, 4
-                        segments=segments,
-                        keypoints=keypoint,
-                        normalized=True,
-                        bbox_format="xywh",
-                    )
+                    {
+                        "im_file": im_file,
+                        "shape": shape,
+                        "cls": lb[:, 0:1],  # n, 1
+                        "bboxes": lb[:, 1:],  # n, 4
+                        "segments": segments,
+                        "keypoints": keypoint,
+                        "normalized": True,
+                        "bbox_format": "xywh",
+                    }
                 )
             if msg:
                 msgs.append(msg)
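The dict(...) to {...} rewrites in this hunk are an unrelated tidy-up rather than part of the import work: a dict literal is assembled by dedicated bytecode, while dict(...) pays a global name lookup plus a call. A quick way to see the (small) difference, with timings that will vary by machine:

import timeit

literal = timeit.timeit("{'im_file': 1, 'shape': 2}", number=1_000_000)
call = timeit.timeit("dict(im_file=1, shape=2)", number=1_000_000)
print(f"literal: {literal:.3f}s  dict(): {call:.3f}s")  # the literal is typically faster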
@@ -245,125 +244,6 @@ class YOLODataset(BaseDataset):
         return new_batch


-# Classification dataloaders -------------------------------------------------------------------------------------------
-class ClassificationDataset(torchvision.datasets.ImageFolder):
-    """
-    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
-    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
-    learning models, with optional image transformations and caching mechanisms to speed up training.
-
-    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
-    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
-    to ensure data integrity and consistency.
-
-    Attributes:
-        cache_ram (bool): Indicates if caching in RAM is enabled.
-        cache_disk (bool): Indicates if caching on disk is enabled.
-        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
-            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
-        torch_transforms (callable): PyTorch transforms to be applied to the images.
-    """
-
-    def __init__(self, root, args, augment=False, prefix=""):
-        """
-        Initialize YOLO object with root, image size, augmentations, and cache settings.
-
-        Args:
-            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
-            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
-                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
-                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
-                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
-            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
-            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
-                debugging. Default is an empty string.
-        """
-        super().__init__(root=root)
-        if augment and args.fraction < 1.0:  # reduce training fraction
-            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
-        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
-        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
-        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
-        self.samples = self.verify_images()  # filter out bad images
-        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
-        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
-        self.torch_transforms = (
-            classify_augmentations(
-                size=args.imgsz,
-                scale=scale,
-                hflip=args.fliplr,
-                vflip=args.flipud,
-                erasing=args.erasing,
-                auto_augment=args.auto_augment,
-                hsv_h=args.hsv_h,
-                hsv_s=args.hsv_s,
-                hsv_v=args.hsv_v,
-            )
-            if augment
-            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
-        )
-
-    def __getitem__(self, i):
-        """Returns subset of data and targets corresponding to given indices."""
-        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
-        if self.cache_ram:
-            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
-                im = self.samples[i][3] = cv2.imread(f)
-        elif self.cache_disk:
-            if not fn.exists():  # load npy
-                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
-            im = np.load(fn)
-        else:  # read image
-            im = cv2.imread(f)  # BGR
-        # Convert NumPy array to PIL image
-        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
-        sample = self.torch_transforms(im)
-        return {"img": sample, "cls": j}
-
-    def __len__(self) -> int:
-        """Return the total number of samples in the dataset."""
-        return len(self.samples)
-
-    def verify_images(self):
-        """Verify all images in dataset."""
-        desc = f"{self.prefix}Scanning {self.root}..."
-        path = Path(self.root).with_suffix(".cache")  # *.cache file path
-
-        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
-            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
-            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
-            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
-            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
-            if LOCAL_RANK in {-1, 0}:
-                d = f"{desc} {nf} images, {nc} corrupt"
-                TQDM(None, desc=d, total=n, initial=n)
-                if cache["msgs"]:
-                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
-            return samples
-
-        # Run scan if *.cache retrieval failed
-        nf, nc, msgs, samples, x = 0, 0, [], [], {}
-        with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
-            pbar = TQDM(results, desc=desc, total=len(self.samples))
-            for sample, nf_f, nc_f, msg in pbar:
-                if nf_f:
-                    samples.append(sample)
-                if msg:
-                    msgs.append(msg)
-                nf += nf_f
-                nc += nc_f
-                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
-            pbar.close()
-        if msgs:
-            LOGGER.info("\n".join(msgs))
-        x["hash"] = get_hash([x[0] for x in self.samples])
-        x["results"] = nf, nc, len(samples), samples
-        x["msgs"] = msgs  # warnings
-        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
-        return samples
-
-
 class YOLOMultiModalDataset(YOLODataset):
     """
     Dataset class for loading object detection and/or segmentation labels in YOLO format.
@@ -447,15 +327,15 @@ class GroundingDataset(YOLODataset):
                 bboxes.append(box)
             lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)
             labels.append(
-                dict(
-                    im_file=im_file,
-                    shape=(h, w),
-                    cls=lb[:, 0:1],  # n, 1
-                    bboxes=lb[:, 1:],  # n, 4
-                    normalized=True,
-                    bbox_format="xywh",
-                    texts=texts,
-                )
+                {
+                    "im_file": im_file,
+                    "shape": (h, w),
+                    "cls": lb[:, 0:1],  # n, 1
+                    "bboxes": lb[:, 1:],  # n, 4
+                    "normalized": True,
+                    "bbox_format": "xywh",
+                    "texts": texts,
+                }
             )
         return labels

@@ -497,3 +377,128 @@ class SemanticDataset(BaseDataset):
     def __init__(self):
         """Initialize a SemanticDataset object."""
         super().__init__()
+
+
+class ClassificationDataset:
+    """
+    Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
+    augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
+    learning models, with optional image transformations and caching mechanisms to speed up training.
+
+    This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
+    in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
+    to ensure data integrity and consistency.
+
+    Attributes:
+        cache_ram (bool): Indicates if caching in RAM is enabled.
+        cache_disk (bool): Indicates if caching on disk is enabled.
+        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
+            file (if caching on disk), and optionally the loaded image array (if caching in RAM).
+        torch_transforms (callable): PyTorch transforms to be applied to the images.
+    """
+
+    def __init__(self, root, args, augment=False, prefix=""):
+        """
+        Initialize YOLO object with root, image size, augmentations, and cache settings.
+
+        Args:
+            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
+            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
+                parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
+                of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
+                `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
+            augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
+            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
+                debugging. Default is an empty string.
+        """
+        import torchvision  # scope for faster 'import ultralytics'
+
+        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
+        self.base = torchvision.datasets.ImageFolder(root=root)
+        self.samples = self.base.samples
+        self.root = self.base.root
+
+        # Initialize attributes
+        if augment and args.fraction < 1.0:  # reduce training fraction
+            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
+        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
+        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
+        self.samples = self.verify_images()  # filter out bad images
+        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
+        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+        self.torch_transforms = (
+            classify_augmentations(
+                size=args.imgsz,
+                scale=scale,
+                hflip=args.fliplr,
+                vflip=args.flipud,
+                erasing=args.erasing,
+                auto_augment=args.auto_augment,
+                hsv_h=args.hsv_h,
+                hsv_s=args.hsv_s,
+                hsv_v=args.hsv_v,
+            )
+            if augment
+            else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction)
+        )
+
+    def __getitem__(self, i):
+        """Returns subset of data and targets corresponding to given indices."""
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.cache_ram:
+            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
+                im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # load npy
+                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        # Convert NumPy array to PIL image
+        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+        sample = self.torch_transforms(im)
+        return {"img": sample, "cls": j}
+
+    def __len__(self) -> int:
+        """Return the total number of samples in the dataset."""
+        return len(self.samples)
+
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f"{self.prefix}Scanning {self.root}..."
+        path = Path(self.root).with_suffix(".cache")  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
+            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop("results")  # found, missing, empty, corrupt, total
+            if LOCAL_RANK in {-1, 0}:
+                d = f"{desc} {nf} images, {nc} corrupt"
+                TQDM(None, desc=d, total=n, initial=n)
+                if cache["msgs"]:
+                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
+            pbar = TQDM(results, desc=desc, total=len(self.samples))
+            for sample, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append(sample)
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f"{desc} {nf} images, {nc} corrupt"
+            pbar.close()
+        if msgs:
+            LOGGER.info("\n".join(msgs))
+        x["hash"] = get_hash([x[0] for x in self.samples])
+        x["results"] = nf, nc, len(samples), samples
+        x["msgs"] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
+        return samples
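The rewritten ClassificationDataset trades inheritance for composition, as its inline comment notes: a base class named in the class statement is resolved when the module is imported, so torchvision.datasets.ImageFolder could never be lazily imported as a parent. Holding the ImageFolder as an attribute moves the import into __init__. A stripped-down sketch of the same move, assuming torchvision is installed and "path/to/dataset" is a hypothetical ImageFolder-style directory:

class LazyImageFolder:
    """Wrap torchvision ImageFolder by composition so torchvision imports lazily."""

    def __init__(self, root):
        import torchvision  # deferred: paid at construction, not at package import

        self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples  # re-expose the attributes callers rely on
        self.root = self.base.root

    def __len__(self):
        return len(self.samples)

dataset = LazyImageFolder("path/to/dataset")
print(len(dataset))

The trade-off is visible in the diff itself: attributes such as samples and root must be re-exported by hand, and isinstance checks against ImageFolder no longer match the new class.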
ultralytics/data/explorer/explorer.py
@@ -9,7 +9,6 @@ import numpy as np
 import torch
 from PIL import Image
 from matplotlib import pyplot as plt
-from pandas import DataFrame
 from tqdm import tqdm

 from ultralytics.data.augment import Format
@@ -172,7 +171,7 @@ class Explorer:

     def sql_query(
         self, query: str, return_type: str = "pandas"
-    ) -> Union[DataFrame, Any, None]:  # pandas.dataframe or pyarrow.Table
+    ) -> Union[Any, None]:  # pandas.DataFrame or pyarrow.Table
         """
         Run a SQL-Like query on the table. Utilizes LanceDB predicate pushdown.

@@ -247,7 +246,7 @@ class Explorer:
         idx: Union[int, List[int]] = None,
         limit: int = 25,
         return_type: str = "pandas",
-    ) -> Union[DataFrame, Any]:  # pandas.dataframe or pyarrow.Table
+    ) -> Any:  # pandas.DataFrame or pyarrow.Table
         """
         Query the table for similar images. Accepts a single image or a list of images.

@@ -312,7 +311,7 @@ class Explorer:
         img = plot_query_result(similar, plot_labels=labels)
         return Image.fromarray(img)

-    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> DataFrame:
+    def similarity_index(self, max_dist: float = 0.2, top_k: float = None, force: bool = False) -> Any:  # pd.DataFrame
         """
         Calculate the similarity index of all the images in the table. Here, the index will contain the data points that
         are max_dist or closer to the image in the embedding space at a given index.
@@ -447,12 +446,11 @@ class Explorer:
         """
         result = prompt_sql_query(query)
         try:
-            df = self.sql_query(result)
+            return self.sql_query(result)
         except Exception as e:
             LOGGER.error("AI generated query is not valid. Please try again with a different prompt")
             LOGGER.error(e)
             return None
-        return df

     def visualize(self, result):
         """
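With 'from pandas import DataFrame' removed from module scope, the return annotations above fall back to Any plus a comment. An alternative this commit does not use, sketched here only for comparison, is typing.TYPE_CHECKING, which preserves the precise annotation for type checkers without importing pandas at runtime:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated only by static type checkers, never at runtime
    from pandas import DataFrame

def similarity_index() -> DataFrame:  # annotation stays a string at runtime
    import pandas  # scope for faster package import

    return pandas.DataFrame()

print(type(similarity_index()))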
ultralytics/data/explorer/gui/dash.py
@@ -3,8 +3,6 @@
 import time
 from threading import Thread

-import pandas as pd
-
 from ultralytics import Explorer
 from ultralytics.utils import ROOT, SETTINGS
 from ultralytics.utils.checks import check_requirements
@@ -148,12 +146,14 @@ def run_ai_query():
             'OpenAI API key not found in settings. Please run yolo settings openai_api_key="..."'
         )
         return
+    import pandas  # scope for faster 'import ultralytics'
+
     st.session_state["error"] = None
     query = st.session_state.get("ai_query")
     if query.rstrip().lstrip():
         exp = st.session_state["explorer"]
         res = exp.ask_ai(query)
-        if not isinstance(res, pd.DataFrame) or res.empty:
+        if not isinstance(res, pandas.DataFrame) or res.empty:
             st.session_state["error"] = "No results found using AI generated query. Try another query or rerun it."
             return
         st.session_state["imgs"] = res["im_file"].to_list()
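Scoping the import inside a callback like run_ai_query means the import statement executes on every call, but only the first call is expensive: Python caches loaded modules in sys.modules, so subsequent executions reduce to a dictionary lookup. A small demonstration:

import sys
import timeit

def dumps_empty():
    import json  # first call loads the module; later calls hit the sys.modules cache

    return json.dumps({})

dumps_empty()
assert "json" in sys.modules  # cached after the first call
print(f"100k calls: {timeit.timeit(dumps_empty, number=100_000):.3f}s")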
ultralytics/data/explorer/utils.py
@@ -5,7 +5,6 @@ from typing import List

 import cv2
 import numpy as np
-import pandas as pd

 from ultralytics.data.augment import LetterBox
 from ultralytics.utils import LOGGER as logger
@@ -64,8 +63,10 @@ def plot_query_result(similar_set, plot_labels=True):
         similar_set (list): Pyarrow or pandas object containing the similar data points
         plot_labels (bool): Whether to plot labels or not
     """
+    import pandas  # scope for faster 'import ultralytics'
+
     similar_set = (
-        similar_set.to_dict(orient="list") if isinstance(similar_set, pd.DataFrame) else similar_set.to_pydict()
+        similar_set.to_dict(orient="list") if isinstance(similar_set, pandas.DataFrame) else similar_set.to_pydict()
     )
     empty_masks = [[[]]]
     empty_boxes = [[]]