From c7ceb84fb6085e767ed664460cd6abec9ced86d8 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 23 Aug 2023 02:28:58 +0200 Subject: [PATCH] `ultralytics 8.0.160` Classify dataset scanning and caching (#4502) --- ultralytics/__init__.py | 2 +- ultralytics/data/dataset.py | 94 ++++++++++++++++++----- ultralytics/data/utils.py | 25 ++++++ ultralytics/models/yolo/classify/train.py | 2 +- ultralytics/models/yolo/classify/val.py | 2 +- 5 files changed, 102 insertions(+), 23 deletions(-) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 6ecd8a28..6a3bfa70 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.159' +__version__ = '8.0.160' from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py index 6d620390..5318ca06 100644 --- a/ultralytics/data/dataset.py +++ b/ultralytics/data/dataset.py @@ -1,5 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license - +import contextlib from itertools import repeat from multiprocessing.pool import ThreadPool from pathlib import Path @@ -10,11 +10,14 @@ import torch import torchvision from tqdm import tqdm -from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable +from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms from .base import BaseDataset -from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label +from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label + +# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 +DATASET_CACHE_VERSION = '1.0.2' class YOLODataset(BaseDataset): @@ -29,7 +32,6 @@ class YOLODataset(BaseDataset): Returns: (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. """ - cache_version = '1.0.2' # dataset labels *.cache version, >= 1.0.0 for YOLOv8 def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs): self.use_segments = use_segments @@ -87,15 +89,7 @@ class YOLODataset(BaseDataset): x['hash'] = get_hash(self.label_files + self.im_files) x['results'] = nf, nm, ne, nc, len(self.im_files) x['msgs'] = msgs # warnings - x['version'] = self.cache_version # cache version - if is_dir_writeable(path.parent): - if path.exists(): - path.unlink() # remove *.cache file if exists - np.save(str(path), x) # save cache for next time - path.with_suffix('.cache.npy').rename(path) # remove .npy suffix - LOGGER.info(f'{self.prefix}New cache created: {path}') - else: - LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.') + save_dataset_cache_file(self.prefix, path, x) return x def get_labels(self): @@ -103,11 +97,8 @@ class YOLODataset(BaseDataset): self.label_files = img2label_paths(self.im_files) cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') try: - import gc - gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 - cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True # load dict - gc.enable() - assert cache['version'] == self.cache_version # matches current version + cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file + assert cache['version'] == DATASET_CACHE_VERSION # matches current version assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash except (FileNotFoundError, AssertionError, AttributeError): cache, exists = self.cache_labels(cache_path), False # run cache ops @@ -116,7 +107,7 @@ class YOLODataset(BaseDataset): nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total if exists and LOCAL_RANK in (-1, 0): d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt' - tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results + tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display results if cache['msgs']: LOGGER.info('\n'.join(cache['msgs'])) # display warnings if nf == 0: # number of labels found @@ -216,7 +207,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True. """ - def __init__(self, root, args, augment=False, cache=False): + def __init__(self, root, args, augment=False, cache=False, prefix=''): """ Initialize YOLO object with root, image size, augmentations, and cache settings. @@ -229,8 +220,10 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): super().__init__(root=root) if augment and args.fraction < 1.0: # reduce training fraction self.samples = self.samples[:round(len(self.samples) * args.fraction)] + self.prefix = colorstr(f'{prefix}: ') if prefix else '' self.cache_ram = cache is True or cache == 'ram' self.cache_disk = cache == 'disk' + self.samples = self.verify_images() # filter out bad images self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im self.torch_transforms = classify_transforms(args.imgsz) self.album_transforms = classify_albumentations( @@ -266,6 +259,67 @@ class ClassificationDataset(torchvision.datasets.ImageFolder): def __len__(self) -> int: return len(self.samples) + def verify_images(self): + """Verify all images in dataset.""" + desc = f'{self.prefix}Scanning {self.root}...' + path = Path(self.root).with_suffix('.cache') # *.cache file path + + with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError): + cache = load_dataset_cache_file(path) # attempt to load a *.cache file + assert cache['version'] == DATASET_CACHE_VERSION # matches current version + assert cache['hash'] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop('results') # found, missing, empty, corrupt, total + if LOCAL_RANK in (-1, 0): + d = f'{desc} {nf} images, {nc} corrupt' + tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) + if cache['msgs']: + LOGGER.info('\n'.join(cache['msgs'])) # display warnings + return samples + + # Run scan if *.cache retrieval failed + nf, nc, msgs, samples, x = 0, 0, [], [], {} + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(func=verify_image, iterable=zip([x[0] for x in self.samples], repeat(self.prefix))) + pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT) + for im_file, nf_f, nc_f, msg in pbar: + if nf_f: + samples.append((im_file, nf)) + if msg: + msgs.append(msg) + nf += nf_f + nc += nc_f + pbar.desc = f'{desc} {nf} images, {nc} corrupt' + pbar.close() + if msgs: + LOGGER.info('\n'.join(msgs)) + x['hash'] = get_hash([x[0] for x in self.samples]) + x['results'] = nf, nc, len(samples), samples + x['msgs'] = msgs # warnings + save_dataset_cache_file(self.prefix, path, x) + return samples + + +def load_dataset_cache_file(path): + """Load an Ultralytics *.cache dictionary from path.""" + import gc + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 + cache = np.load(str(path), allow_pickle=True).item() # load dict + gc.enable() + return cache + + +def save_dataset_cache_file(prefix, path, x): + """Save an Ultralytics dataset *.cache dictionary x to path.""" + x['version'] = DATASET_CACHE_VERSION # add cache version + if is_dir_writeable(path.parent): + if path.exists(): + path.unlink() # remove *.cache file if exists + np.save(str(path), x) # save cache for next time + path.with_suffix('.cache.npy').rename(path) # remove .npy suffix + LOGGER.info(f'{prefix}New cache created: {path}') + else: + LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.') + # TODO: support semantic segmentation class SemanticDataset(BaseDataset): diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index a2d04c0e..c8d5e99e 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -57,6 +57,31 @@ def exif_size(img: Image.Image): return s +def verify_image(args): + """Verify one image.""" + im_file, prefix = args + # Number (found, corrupt), message + nf, nc, msg = 0, 0, '' + try: + im = Image.open(im_file) + im.verify() # PIL verify + shape = exif_size(im) # image size + shape = (shape[1], shape[0]) # hw + assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels' + assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}' + if im.format.lower() in ('jpg', 'jpeg'): + with open(im_file, 'rb') as f: + f.seek(-2, 2) + if f.read() != b'\xff\xd9': # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100) + msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved' + nf = 1 + except Exception as e: + nc = 1 + msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}' + return im_file, nf, nc, msg + + def verify_image_label(args): """Verify one image-label pair.""" im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py index b09b5201..c9393ac6 100644 --- a/ultralytics/models/yolo/classify/train.py +++ b/ultralytics/models/yolo/classify/train.py @@ -79,7 +79,7 @@ class ClassificationTrainer(BaseTrainer): return ckpt def build_dataset(self, img_path, mode='train', batch=None): - return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train') + return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train', prefix=mode) def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py index e9de4fdd..80606292 100644 --- a/ultralytics/models/yolo/classify/val.py +++ b/ultralytics/models/yolo/classify/val.py @@ -77,7 +77,7 @@ class ClassificationValidator(BaseValidator): return self.metrics.results_dict def build_dataset(self, img_path): - return ClassificationDataset(root=img_path, args=self.args, augment=False) + return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split) def get_dataloader(self, dataset_path, batch_size): """Builds and returns a data loader for classification tasks with given parameters."""