ultralytics 8.1.40 search in Python sets {} for speed (#9450)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-04-01 00:16:52 +02:00 committed by GitHub
parent 30484d5925
commit ea527507fe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 97 additions and 93 deletions

View file

@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in (4, 9), "grid must be equal to 4 or 9."
assert n in {4, 9}, "grid must be equal to 4 or 9."
super().__init__(dataset=dataset, p=p)
self.dataset = dataset
self.imgsz = imgsz
@ -685,7 +685,7 @@ class RandomFlip:
Default is 'horizontal'.
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
"""
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
assert 0 <= p <= 1.0
self.p = p

View file

@ -15,7 +15,7 @@ import psutil
from torch.utils.data import Dataset
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from .utils import HELP_URL, IMG_FORMATS
from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
class BaseDataset(Dataset):
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
assert im_files, f"{self.prefix}No images found in {img_path}"
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
except Exception as e:
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
if self.fraction < 1:

View file

@ -481,7 +481,7 @@ def merge_multi_segment(segments):
segments[i] = np.roll(segments[i], -idx[0], axis=0)
segments[i] = np.concatenate([segments[i], segments[i][:1]])
# Deal with the first segment and the last one
if i in [0, len(idx_list) - 1]:
if i in {0, len(idx_list) - 1}:
s.append(segments[i])
else:
idx = [0, idx[1] - idx[0]]
@ -489,7 +489,7 @@ def merge_multi_segment(segments):
else:
for i in range(len(idx_list) - 1, -1, -1):
if i not in [0, len(idx_list) - 1]:
if i not in {0, len(idx_list) - 1}:
idx = idx_list[i]
nidx = abs(idx[1] - idx[0])
s.append(segments[i][nidx:])

View file

@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
total = len(self.im_files)
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
raise ValueError(
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
# Display cache
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
if exists and LOCAL_RANK in (-1, 0):
if exists and LOCAL_RANK in {-1, 0}:
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
if cache["msgs"]:
@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
value = values[i]
if k == "img":
value = torch.stack(value, 0)
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
value = torch.cat(value, 0)
new_batch[k] = value
new_batch["batch_idx"] = list(new_batch["batch_idx"])
@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
if LOCAL_RANK in (-1, 0):
if LOCAL_RANK in {-1, 0}:
d = f"{desc} {nf} images, {nc} corrupt"
TQDM(None, desc=d, total=n, initial=n)
if cache["msgs"]:

View file

@ -15,7 +15,7 @@ import requests
import torch
from PIL import Image
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
from ultralytics.utils.checks import check_requirements
@ -83,7 +83,7 @@ class LoadStreams:
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f"{i + 1}/{n}: {s}... "
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
s = get_best_youtube_url(s)
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
@ -291,8 +291,14 @@ class LoadImagesAndVideos:
else:
raise FileNotFoundError(f"{p} does not exist")
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
# Define files as images or videos
images, videos = [], []
for f in files:
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
if suffix in IMG_FORMATS:
images.append(f)
elif suffix in VID_FORMATS:
videos.append(f)
ni, nv = len(images), len(videos)
self.files = images + videos
@ -307,10 +313,7 @@ class LoadImagesAndVideos:
else:
self.cap = None
if self.nf == 0:
raise FileNotFoundError(
f"No images or videos found in {p}. "
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
)
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""

View file

@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
- train
- val
"""
assert split in ["train", "val"]
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
im_dir = Path(data_root) / "images" / split
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
im_files = glob(str(Path(data_root) / "images" / split / "*"))

View file

@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
def img2label_paths(img_paths):
@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
exif = img.getexif()
if exif:
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
if rotation in [6, 8]: # rotation 270 or 90
if rotation in {6, 8}: # rotation 270 or 90
s = s[1], s[0]
return s
@ -79,8 +80,8 @@ def verify_image(args):
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
if im.format.lower() in {"jpg", "jpeg"}:
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
@ -105,8 +106,8 @@ def verify_image_label(args):
shape = exif_size(im) # image size
shape = (shape[1], shape[0]) # hw
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
if im.format.lower() in ("jpg", "jpeg"):
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
if im.format.lower() in {"jpg", "jpeg"}:
with open(im_file, "rb") as f:
f.seek(-2, 2)
if f.read() != b"\xff\xd9": # corrupt JPEG
@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
else: # python script
exec(s, {"yaml": data})
dt = f"({round(time.time() - t, 1)}s)"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt}"
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt}"
LOGGER.info(f"Dataset download {s}\n")
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
# Download (optional if dataset=https://file.zip is passed directly)
if str(dataset).startswith(("http:/", "https:/")):
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
file = check_file(dataset)
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)