ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
30484d5925
commit
ea527507fe
41 changed files with 97 additions and 93 deletions
|
|
@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
|
|||
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
|
||||
assert n in (4, 9), "grid must be equal to 4 or 9."
|
||||
assert n in {4, 9}, "grid must be equal to 4 or 9."
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
|
|
@ -685,7 +685,7 @@ class RandomFlip:
|
|||
Default is 'horizontal'.
|
||||
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
|
||||
"""
|
||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import psutil
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
|
||||
from .utils import HELP_URL, IMG_FORMATS
|
||||
from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
|
|
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
|
|||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found in {img_path}"
|
||||
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
if self.fraction < 1:
|
||||
|
|
|
|||
|
|
@ -481,7 +481,7 @@ def merge_multi_segment(segments):
|
|||
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
||||
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
||||
# Deal with the first segment and the last one
|
||||
if i in [0, len(idx_list) - 1]:
|
||||
if i in {0, len(idx_list) - 1}:
|
||||
s.append(segments[i])
|
||||
else:
|
||||
idx = [0, idx[1] - idx[0]]
|
||||
|
|
@ -489,7 +489,7 @@ def merge_multi_segment(segments):
|
|||
|
||||
else:
|
||||
for i in range(len(idx_list) - 1, -1, -1):
|
||||
if i not in [0, len(idx_list) - 1]:
|
||||
if i not in {0, len(idx_list) - 1}:
|
||||
idx = idx_list[i]
|
||||
nidx = abs(idx[1] - idx[0])
|
||||
s.append(segments[i][nidx:])
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
|
|||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
total = len(self.im_files)
|
||||
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
|
||||
raise ValueError(
|
||||
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
||||
|
|
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
|||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in (-1, 0):
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
||||
if cache["msgs"]:
|
||||
|
|
@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
|
|||
value = values[i]
|
||||
if k == "img":
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
|
||||
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
|
|
@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop("results") # found, corrupt, total, samples
|
||||
if LOCAL_RANK in (-1, 0):
|
||||
if LOCAL_RANK in {-1, 0}:
|
||||
d = f"{desc} {nf} images, {nc} corrupt"
|
||||
TQDM(None, desc=d, total=n, initial=n)
|
||||
if cache["msgs"]:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import requests
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
|
||||
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ class LoadStreams:
|
|||
for i, s in enumerate(sources): # index, source
|
||||
# Start thread to read frames from video stream
|
||||
st = f"{i + 1}/{n}: {s}... "
|
||||
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
|
||||
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
|
||||
s = get_best_youtube_url(s)
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
|
|
@ -291,8 +291,14 @@ class LoadImagesAndVideos:
|
|||
else:
|
||||
raise FileNotFoundError(f"{p} does not exist")
|
||||
|
||||
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
|
||||
# Define files as images or videos
|
||||
images, videos = [], []
|
||||
for f in files:
|
||||
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
|
||||
if suffix in IMG_FORMATS:
|
||||
images.append(f)
|
||||
elif suffix in VID_FORMATS:
|
||||
videos.append(f)
|
||||
ni, nv = len(images), len(videos)
|
||||
|
||||
self.files = images + videos
|
||||
|
|
@ -307,10 +313,7 @@ class LoadImagesAndVideos:
|
|||
else:
|
||||
self.cap = None
|
||||
if self.nf == 0:
|
||||
raise FileNotFoundError(
|
||||
f"No images or videos found in {p}. "
|
||||
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||
)
|
||||
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
|
|||
- train
|
||||
- val
|
||||
"""
|
||||
assert split in ["train", "val"]
|
||||
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
|
||||
im_dir = Path(data_root) / "images" / split
|
||||
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
|
||||
im_files = glob(str(Path(data_root) / "images" / split / "*"))
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
|
|||
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
|
||||
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
|
|
@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
|
|||
exif = img.getexif()
|
||||
if exif:
|
||||
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
|
||||
if rotation in [6, 8]: # rotation 270 or 90
|
||||
if rotation in {6, 8}: # rotation 270 or 90
|
||||
s = s[1], s[0]
|
||||
return s
|
||||
|
||||
|
|
@ -79,8 +80,8 @@ def verify_image(args):
|
|||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||
if im.format.lower() in {"jpg", "jpeg"}:
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
|
|
@ -105,8 +106,8 @@ def verify_image_label(args):
|
|||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||
if im.format.lower() in {"jpg", "jpeg"}:
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
|
|
@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
else: # python script
|
||||
exec(s, {"yaml": data})
|
||||
dt = f"({round(time.time() - t, 1)}s)"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}\n")
|
||||
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
|
||||
|
||||
|
|
@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
|
|||
# Download (optional if dataset=https://file.zip is passed directly)
|
||||
if str(dataset).startswith(("http:/", "https:/")):
|
||||
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
|
||||
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
|
||||
file = check_file(dataset)
|
||||
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue