Python refactorings and simplifications (#7549)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Hassaan Farooq <103611273+hassaanfarooq01@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Glenn Jocher 2024-01-12 19:34:03 +01:00 committed by GitHub
parent 0da13831cf
commit f6309b8e70
40 changed files with 127 additions and 189 deletions


@@ -41,7 +41,7 @@ def auto_annotate(data, det_model="yolov8x.pt", sam_model="sam_b.pt", device="",
             sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
             segments = sam_results[0].masks.xyn  # noqa

-            with open(f"{str(Path(output_dir) / Path(result.path).stem)}.txt", "w") as f:
+            with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w") as f:
                 for i in range(len(segments)):
                     s = segments[i]
                     if len(s) == 0:
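Reviewer note: the change above only drops a redundant `str()` call, since f-string interpolation already converts a `Path` to text. A minimal standalone sketch (the paths are illustrative, not taken from the repo):

```python
from pathlib import Path

output_dir = Path("runs/annotate")  # hypothetical output directory
result_path = "images/zidane.jpg"  # hypothetical prediction path

# f-strings call str() on the interpolated value, so the explicit str() was redundant.
old = f"{str(output_dir / Path(result_path).stem)}.txt"
new = f"{output_dir / Path(result_path).stem}.txt"
assert old == new  # both give "runs/annotate/zidane.txt" on POSIX systems
```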


@@ -15,7 +15,6 @@ from ultralytics.utils.instance import Instances
 from ultralytics.utils.metrics import bbox_ioa
 from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr
 from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
-
 from .utils import polygons2masks, polygons2masks_overlap

 DEFAULT_MEAN = (0.0, 0.0, 0.0)
@@ -1028,7 +1027,7 @@ def classify_transforms(
     if isinstance(size, (tuple, list)):
         assert len(size) == 2
-        scale_size = tuple([math.floor(x / crop_fraction) for x in size])
+        scale_size = tuple(math.floor(x / crop_fraction) for x in size)
     else:
         scale_size = math.floor(size / crop_fraction)
         scale_size = (scale_size, scale_size)
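Reviewer note: `tuple()` accepts any iterable, so the generator expression produces the same result as the list comprehension without building a throwaway list. A quick sketch with made-up values (not from the config):

```python
import math

size = (224, 224)  # example classification input size
crop_fraction = 0.875  # example crop fraction

# Old form builds a temporary list; new form feeds a generator straight into tuple().
old = tuple([math.floor(x / crop_fraction) for x in size])
new = tuple(math.floor(x / crop_fraction) for x in size)
assert old == new == (256, 256)
```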


@@ -15,7 +15,6 @@ import psutil
 from torch.utils.data import Dataset

 from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
-
 from .utils import HELP_URL, IMG_FORMATS


@@ -22,7 +22,6 @@ from ultralytics.data.loaders import (
 from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.utils import RANK, colorstr
 from ultralytics.utils.checks import check_file
-
 from .dataset import YOLODataset
 from .utils import PIN_MEMORY


@@ -12,7 +12,6 @@ from PIL import Image
 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
 from ultralytics.utils.ops import resample_segments
-
 from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
 from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label


@@ -7,9 +7,9 @@ from typing import Any, List, Tuple, Union
 import cv2
 import numpy as np
 import torch
+from PIL import Image
 from matplotlib import pyplot as plt
 from pandas import DataFrame
-from PIL import Image
 from tqdm import tqdm

 from ultralytics.data.augment import Format
@@ -17,7 +17,6 @@ from ultralytics.data.dataset import YOLODataset
 from ultralytics.data.utils import check_det_dataset
 from ultralytics.models.yolo.model import YOLO
 from ultralytics.utils import LOGGER, IterableSimpleNamespace, checks
-
 from .utils import get_sim_index_schema, get_table_schema, plot_query_result, prompt_sql_query, sanitize_batch
@@ -188,10 +187,10 @@ class Explorer:
             result = exp.sql_query(query)
             ```
         """
-        assert return_type in [
+        assert return_type in {
             "pandas",
             "arrow",
-        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
+        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
         import duckdb

         if self.table is None:
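Reviewer note: swapping the list literal for a set literal keeps the assertion's behavior identical, since `in` works on both; a set is simply the idiomatic container for a fixed collection of allowed values. A standalone sketch of the same check, independent of the Explorer class (the helper name is hypothetical):

```python
def check_return_type(return_type: str) -> None:
    """Mirror of the membership assertion above, using a set literal."""
    assert return_type in {
        "pandas",
        "arrow",
    }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"


check_return_type("arrow")  # passes silently
try:
    check_return_type("csv")  # fails: not an allowed return type
except AssertionError as err:
    print(err)  # Return type should be either `pandas` or `arrow`, but got csv
```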
@@ -208,10 +207,10 @@ class Explorer:
         LOGGER.info(f"Running query: {query}")

         rs = duckdb.sql(query)
-        if return_type == "pandas":
-            return rs.df()
-        elif return_type == "arrow":
+        if return_type == "arrow":
             return rs.arrow()
+        elif return_type == "pandas":
+            return rs.df()

     def plot_sql_query(self, query: str, labels: bool = True) -> Image.Image:
         """
@@ -264,17 +263,17 @@ class Explorer:
             similar = exp.get_similar(img='https://ultralytics.com/images/zidane.jpg')
             ```
         """
-        assert return_type in [
+        assert return_type in {
             "pandas",
             "arrow",
-        ], f"Return type should be either `pandas` or `arrow`, but got {return_type}"
+        }, f"Return type should be either `pandas` or `arrow`, but got {return_type}"
         img = self._check_imgs_or_idxs(img, idx)
         similar = self.query(img, limit=limit)

-        if return_type == "pandas":
-            return similar.to_pandas()
-        elif return_type == "arrow":
+        if return_type == "arrow":
             return similar
+        elif return_type == "pandas":
+            return similar.to_pandas()

     def plot_similar(
         self,


@@ -98,9 +98,9 @@ def plot_query_result(similar_set, plot_labels=True):
                 plot_kpts.append(kpt)
         batch_idx.append(np.ones(len(np.array(bboxes[i], dtype=np.float32))) * i)
     imgs = np.stack(imgs, axis=0)
-    masks = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
-    kpts = np.concatenate(plot_kpts, axis=0) if len(plot_kpts) > 0 else np.zeros((0, 51), dtype=np.float32)
-    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if len(plot_boxes) > 0 else np.zeros(0, dtype=np.float32)
+    masks = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
+    kpts = np.concatenate(plot_kpts, axis=0) if plot_kpts else np.zeros((0, 51), dtype=np.float32)
+    boxes = xyxy2xywh(np.concatenate(plot_boxes, axis=0)) if plot_boxes else np.zeros(0, dtype=np.float32)
     batch_idx = np.concatenate(batch_idx, axis=0)
     cls = np.concatenate([np.array(c, dtype=np.int32) for c in cls], axis=0)
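Reviewer note: the simplification above relies on an empty Python list being falsy, so `if plot_masks` takes the same branch as `if len(plot_masks) > 0`. This only holds because these are plain lists; truth-testing a NumPy array directly would raise an ambiguity error. A minimal sketch with dummy arrays (shapes are illustrative only):

```python
import numpy as np

plot_masks = []  # pretend no masks were collected for this batch

# Empty list is falsy, so both guards fall through to the empty placeholder array.
masks_old = np.stack(plot_masks, axis=0) if len(plot_masks) > 0 else np.zeros(0, dtype=np.uint8)
masks_new = np.stack(plot_masks, axis=0) if plot_masks else np.zeros(0, dtype=np.uint8)
assert masks_old.shape == masks_new.shape == (0,)

# With entries present, both spellings stack the same arrays.
plot_masks = [np.ones((4, 4), dtype=np.uint8), np.zeros((4, 4), dtype=np.uint8)]
assert np.stack(plot_masks, axis=0).shape == (2, 4, 4)
```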


@@ -139,10 +139,9 @@ def get_window_obj(anno, windows, iof_thr=0.7):
         label[:, 2::2] *= h
         iofs = bbox_iof(label[:, 1:], windows)
         # Unnormalized and misaligned coordinates
-        window_anns = [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]
+        return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]  # window_anns
     else:
-        window_anns = [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]
-    return window_anns
+        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns


 def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
@@ -170,7 +169,7 @@ def crop_and_save(anno, windows, window_objs, im_dir, lb_dir):
     name = Path(anno["filepath"]).stem
     for i, window in enumerate(windows):
         x_start, y_start, x_stop, y_stop = window.tolist()
-        new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
+        new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
         patch_im = im[y_start:y_stop, x_start:x_stop]
         ph, pw = patch_im.shape[:2]
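Reviewer note: the f-string form produces byte-for-byte the same patch name as the old `str()`-concatenation while being easier to read. A tiny check with hypothetical window coordinates (values are not from the dataset):

```python
name = "P0006"  # hypothetical DOTA image stem
x_start, y_start, x_stop, y_stop = 0, 0, 1024, 1024  # hypothetical crop window

old = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
new = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
assert old == new == "P0006__1024__0___0"
```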
@@ -271,7 +270,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
     save_dir.mkdir(parents=True, exist_ok=True)

     im_dir = Path(os.path.join(data_root, "images/test"))
-    assert im_dir.exists(), f"Can't find {str(im_dir)}, please check your data root."
+    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
     im_files = glob(str(im_dir / "*"))
     for im_file in tqdm(im_files, total=len(im_files), desc="test"):
         w, h = exif_size(Image.open(im_file))
@@ -280,7 +279,7 @@ def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=[1.0]):
         name = Path(im_file).stem
         for window in windows:
             x_start, y_start, x_stop, y_stop = window.tolist()
-            new_name = name + "__" + str(x_stop - x_start) + "__" + str(x_start) + "___" + str(y_start)
+            new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
             patch_im = im[y_start:y_stop, x_start:x_stop]
             cv2.imwrite(os.path.join(str(save_dir), f"{new_name}.jpg"), patch_im)