ultralytics 8.0.233 improve Classify train augmentations (#4546)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent 6218b82072
commit 73dbb41920
13 changed files with 253 additions and 108 deletions
ultralytics/data/augment.py
@@ -14,9 +14,14 @@ from ultralytics.utils.checks import check_version
 from ultralytics.utils.instance import Instances
 from ultralytics.utils.metrics import bbox_ioa
 from ultralytics.utils.ops import segment2box
+from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
 
 from .utils import polygons2masks, polygons2masks_overlap
 
+DEFAULT_MEAN = (0.0, 0.0, 0.0)
+DEFAULT_STD = (1.0, 1.0, 1.0)
+DEFAULT_CROP_FRACTION = 1.0
+
 
 # TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
 class BaseTransform:
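Worth noting: with DEFAULT_MEAN all zeros and DEFAULT_STD all ones, the Normalize step used below is an identity mapping unless callers pass real statistics. A minimal sketch (plain torchvision, arbitrary tensor):

import torch
import torchvision.transforms as T

x = torch.rand(3, 224, 224)

# New defaults: (x - 0) / 1 leaves values untouched
identity = T.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))
assert torch.equal(identity(x), x)

# ImageNet statistics, by contrast, shift and rescale each channel
imagenet = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
print(imagenet(x).mean())  # channel means pushed away from ~0.5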
@@ -982,65 +987,144 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
 # Classification augmentations -----------------------------------------------------------------------------------------
-def classify_transforms(size=224, rect=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)):  # IMAGENET_MEAN, IMAGENET_STD
-    """Transforms to apply if albumentations not installed."""
-    if not isinstance(size, int):
-        raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
-    transforms = [ClassifyLetterBox(size, auto=True) if rect else CenterCrop(size), ToTensor()]
-    if any(mean) or any(std):
-        transforms.append(T.Normalize(mean, std, inplace=True))
-    return T.Compose(transforms)
-
-
-def hsv2colorjitter(h, s, v):
-    """Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)."""
-    return v, v, s, h
-
-
-def classify_albumentations(
-        augment=True,
-        size=224,
-        scale=(0.08, 1.0),
-        hflip=0.5,
-        vflip=0.0,
-        hsv_h=0.015,  # image HSV-Hue augmentation (fraction)
-        hsv_s=0.7,  # image HSV-Saturation augmentation (fraction)
-        hsv_v=0.4,  # image HSV-Value augmentation (fraction)
-        mean=(0.0, 0.0, 0.0),  # IMAGENET_MEAN
-        std=(1.0, 1.0, 1.0),  # IMAGENET_STD
-        auto_aug=False,
-):
-    """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
-    prefix = colorstr('albumentations: ')
-    try:
-        import albumentations as A
-        from albumentations.pytorch import ToTensorV2
-
-        check_version(A.__version__, '1.0.3', hard=True)  # version requirement
-        if augment:  # Resize and crop
-            T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
-            if auto_aug:
-                # TODO: implement AugMix, AutoAug & RandAug in albumentations
-                LOGGER.info(f'{prefix}auto augmentations are currently not supported')
-            else:
-                if hflip > 0:
-                    T += [A.HorizontalFlip(p=hflip)]
-                if vflip > 0:
-                    T += [A.VerticalFlip(p=vflip)]
-                if any((hsv_h, hsv_s, hsv_v)):
-                    T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))]  # brightness, contrast, saturation, hue
-        else:  # Use fixed crop for eval set (reproducibility)
-            T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
-        T += [A.Normalize(mean=mean, std=std), ToTensorV2()]  # Normalize and convert to Tensor
-        LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
-        return A.Compose(T)
-
-    except ImportError:  # package not installed, skip
-        pass
-    except Exception as e:
-        LOGGER.info(f'{prefix}{e}')
+def classify_transforms(
+    size=224,
+    mean=DEFAULT_MEAN,
+    std=DEFAULT_STD,
+    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+    crop_fraction: float = DEFAULT_CROP_FRACTION,
+):
+    """
+    Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
+
+    Args:
+        size (int): image size
+        mean (tuple): mean values of RGB channels
+        std (tuple): std values of RGB channels
+        interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
+        crop_fraction (float): fraction of image to crop. default is 1.0.
+
+    Returns:
+        T.Compose: torchvision transforms
+    """
+    if isinstance(size, (tuple, list)):
+        assert len(size) == 2
+        scale_size = tuple(math.floor(x / crop_fraction) for x in size)
+    else:
+        scale_size = math.floor(size / crop_fraction)
+        scale_size = (scale_size, scale_size)
+
+    # Aspect ratio is preserved, the crop is centered, no borders are added; part of the image may be lost
+    if scale_size[0] == scale_size[1]:
+        # Simple case, use torchvision built-in Resize with shortest-edge mode (scalar size arg)
+        tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
+    else:
+        # Resize the shortest edge to the matching target dim for a non-square target
+        tfl = [T.Resize(scale_size)]
+    tfl += [T.CenterCrop(size)]
+
+    tfl += [T.ToTensor(), T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
+
+    return T.Compose(tfl)
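For illustration (not part of the diff), a usage sketch of the new eval pipeline, assuming a local RGB image:

from PIL import Image
from ultralytics.data.augment import classify_transforms

# size=224, crop_fraction=0.9 -> shortest edge resized to floor(224 / 0.9) = 248,
# then a 224x224 center crop, ToTensor and (identity by default) Normalize
tf = classify_transforms(size=224, crop_fraction=0.9)
im = Image.open('example.jpg')  # hypothetical image path
x = tf(im)
print(x.shape)  # torch.Size([3, 224, 224])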
+
+
+# Classification augmentations train ---------------------------------------------------------------------------------
+def classify_augmentations(
+    size=224,
+    mean=DEFAULT_MEAN,
+    std=DEFAULT_STD,
+    scale=None,
+    ratio=None,
+    hflip=0.5,
+    vflip=0.0,
+    auto_augment=None,
+    hsv_h=0.015,  # image HSV-Hue augmentation (fraction)
+    hsv_s=0.4,  # image HSV-Saturation augmentation (fraction)
+    hsv_v=0.4,  # image HSV-Value augmentation (fraction)
+    force_color_jitter=False,
+    erasing=0.,
+    interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+):
+    """
+    Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
+
+    Args:
+        size (int): image size
+        scale (tuple): scale range of the image. default is (0.08, 1.0)
+        ratio (tuple): aspect ratio range of the image. default is (3./4., 4./3.)
+        mean (tuple): mean values of RGB channels
+        std (tuple): std values of RGB channels
+        hflip (float): probability of horizontal flip
+        vflip (float): probability of vertical flip
+        auto_augment (str): auto augmentation policy. can be 'randaugment', 'augmix', 'autoaugment' or None
+        hsv_h (float): image HSV-Hue augmentation (fraction)
+        hsv_s (float): image HSV-Saturation augmentation (fraction)
+        hsv_v (float): image HSV-Value augmentation (fraction)
+        force_color_jitter (bool): force to apply color jitter even if auto augment is enabled
+        erasing (float): probability of random erasing
+        interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
+
+    Returns:
+        T.Compose: torchvision transforms
+    """
+    if not isinstance(size, int):
+        raise TypeError(f'classify_augmentations() size {size} must be integer, not (list, tuple)')
+    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
+    ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
+    primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
+    if hflip > 0.:
+        primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
+    if vflip > 0.:
+        primary_tfl += [T.RandomVerticalFlip(p=vflip)]
+
+    secondary_tfl = []
+    disable_color_jitter = False
+    if auto_augment:
+        assert isinstance(auto_augment, str)
+        # color jitter is typically disabled if AA/RA is on,
+        # this allows an override without breaking old hparam cfgs
+        disable_color_jitter = not force_color_jitter
+        if auto_augment == 'randaugment':
+            if TORCHVISION_0_11:
+                secondary_tfl += [T.RandAugment(interpolation=interpolation)]
+            else:
+                LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
+
+        elif auto_augment == 'augmix':
+            if TORCHVISION_0_13:
+                secondary_tfl += [T.AugMix(interpolation=interpolation)]
+            else:
+                LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
+
+        elif auto_augment == 'autoaugment':
+            if TORCHVISION_0_10:
+                secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
+            else:
+                LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
+
+        else:
+            raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
+                             f'"augmix", "autoaugment" or None')
+
+    if not disable_color_jitter:
+        secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
+
+    final_tfl = [
+        T.ToTensor(),
+        T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
+        T.RandomErasing(p=erasing, inplace=True)]
+
+    return T.Compose(primary_tfl + secondary_tfl + final_tfl)
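Likewise for illustration only, a sketch of the training pipeline, assuming a local image path and torchvision >= 0.11 so that 'randaugment' is available:

from PIL import Image
from ultralytics.data.augment import classify_augmentations

# RandomResizedCrop + RandomHorizontalFlip + RandAugment + ToTensor + Normalize + RandomErasing;
# ColorJitter is skipped because auto_augment is set and force_color_jitter is False
tf = classify_augmentations(size=224, auto_augment='randaugment', erasing=0.4)
x = tf(Image.open('example.jpg'))  # hypothetical image path; tensor of shape (3, 224, 224)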
+
+
+# NOTE: keep this class for backward compatibility
 class ClassifyLetterBox:
     """
     YOLOv8 LetterBox class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
@@ -1091,6 +1175,7 @@ class ClassifyLetterBox:
         return im_out
 
 
+# NOTE: keep this class for backward compatibility
 class CenterCrop:
     """YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
     T.Compose([CenterCrop(size), ToTensor()]).
@@ -1117,6 +1202,7 @@ class CenterCrop:
         return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
 
 
+# NOTE: keep this class for backward compatibility
 class ToTensor:
     """YOLOv8 ToTensor class for image preprocessing, i.e., T.Compose([LetterBox(size), ToTensor()])."""
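These classes remain importable for older pipelines that, unlike the new torchvision path, operate directly on BGR ndarrays from cv2. A sketch of that legacy usage, assuming a local image path:

import cv2
import torchvision.transforms as T
from ultralytics.data.augment import CenterCrop, ToTensor

# Legacy-style pipeline on a cv2 (BGR, HWC) ndarray; ToTensor handles the BGR->RGB flip
legacy = T.Compose([CenterCrop(224), ToTensor()])
im = cv2.imread('example.jpg')  # hypothetical path
x = legacy(im)  # float tensor, CHW, RGB order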
ultralytics/data/dataset.py

@@ -8,10 +8,11 @@ import cv2
 import numpy as np
 import torch
 import torchvision
+from PIL import Image
 
 from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
 
-from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
+from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
 from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
@@ -225,19 +226,17 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         self.cache_disk = cache == 'disk'
         self.samples = self.verify_images()  # filter out bad images
         self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
-        self.torch_transforms = classify_transforms(args.imgsz, rect=args.rect)
-        self.album_transforms = classify_albumentations(
-            augment=augment,
-            size=args.imgsz,
-            scale=(1.0 - args.scale, 1.0),  # (0.08, 1.0)
-            hflip=args.fliplr,
-            vflip=args.flipud,
-            hsv_h=args.hsv_h,  # HSV-Hue augmentation (fraction)
-            hsv_s=args.hsv_s,  # HSV-Saturation augmentation (fraction)
-            hsv_v=args.hsv_v,  # HSV-Value augmentation (fraction)
-            mean=(0.0, 0.0, 0.0),  # IMAGENET_MEAN
-            std=(1.0, 1.0, 1.0),  # IMAGENET_STD
-            auto_aug=False) if augment else None
+        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
+        self.torch_transforms = classify_augmentations(size=args.imgsz,
+                                                       scale=scale,
+                                                       hflip=args.fliplr,
+                                                       vflip=args.flipud,
+                                                       erasing=args.erasing,
+                                                       auto_augment=args.auto_augment,
+                                                       hsv_h=args.hsv_h,
+                                                       hsv_s=args.hsv_s,
+                                                       hsv_v=args.hsv_v) if augment else classify_transforms(
+                                                           size=args.imgsz, crop_fraction=args.crop_fraction)
 
     def __getitem__(self, i):
         """Returns subset of data and targets corresponding to given indices."""
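For context, the knobs surfaced here (args.erasing, args.auto_augment, args.crop_fraction) suggest a training call along these lines; the exact cfg keys, values, and dataset name are assumptions, not confirmed by this diff:

from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')
# Hypothetical hyperparameter values; augment=True routes to classify_augmentations,
# while validation uses classify_transforms with crop_fraction
model.train(data='imagenette160', epochs=3, imgsz=224,
            erasing=0.4, auto_augment='randaugment', crop_fraction=1.0)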
@@ -250,10 +249,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
             im = np.load(fn)
         else:  # read image
             im = cv2.imread(f)  # BGR
-        if self.album_transforms:
-            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
-        else:
-            sample = self.torch_transforms(im)
+        # Convert NumPy array to PIL image
+        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+        sample = self.torch_transforms(im)
         return {'img': sample, 'cls': j}
 
     def __len__(self) -> int:
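The conversion step matters because cv2.imread returns BGR arrays while torchvision's PIL-based transforms expect RGB. A standalone sketch (hypothetical path):

import cv2
import numpy as np
from PIL import Image

bgr = cv2.imread('example.jpg')             # ndarray, BGR channel order
rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # swap to RGB before PIL wrapping
pil = Image.fromarray(rgb)
assert np.array_equal(np.asarray(pil), rgb)  # PIL round-trips the RGB array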