ultralytics 8.0.233 improve Classify train augmentations (#4546)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
fatih 2024-01-04 00:17:10 +03:00 committed by GitHub
parent 6218b82072
commit 73dbb41920
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 253 additions and 108 deletions

View file

@ -14,9 +14,14 @@ from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0)
DEFAULT_CROP_FTACTION = 1.0
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
class BaseTransform:
@ -982,65 +987,144 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224, rect=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
"""Transforms to apply if albumentations not installed."""
def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
crop_fraction: float = DEFAULT_CROP_FTACTION,
):
"""
Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
Args:
size (int): image size
mean (tuple): mean values of RGB channels
std (tuple): std values of RGB channels
interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
crop_fraction (float): fraction of image to crop. default is 1.0.
Returns:
T.Compose: torchvision transforms
"""
if isinstance(size, (tuple, list)):
assert len(size) == 2
scale_size = tuple([math.floor(x / crop_fraction) for x in size])
else:
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
else:
# resize shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
tfl += [T.ToTensor(), T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
)]
return T.Compose(tfl)
# Classification augmentations train ---------------------------------------------------------------------------------------
def classify_augmentations(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.0,
auto_augment=None,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
"""
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
Args:
size (int): image size
scale (tuple): scale range of the image. default is (0.08, 1.0)
ratio (tuple): aspect ratio range of the image. default is (3./4., 4./3.)
mean (tuple): mean values of RGB channels
std (tuple): std values of RGB channels
hflip (float): probability of horizontal flip
vflip (float): probability of vertical flip
auto_augment (str): auto augmentation policy. can be 'randaugment', 'augmix', 'autoaugment' or None.
hsv_h (float): image HSV-Hue augmentation (fraction)
hsv_s (float): image HSV-Saturation augmentation (fraction)
hsv_v (float): image HSV-Value augmentation (fraction)
contrast (float): image contrast augmentation (fraction)
force_color_jitter (bool): force to apply color jitter even if auto augment is enabled
erasing (float): probability of random erasing
interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
Returns:
T.Compose: torchvision transforms
"""
# Transforms to apply if albumentations not installed
if not isinstance(size, int):
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
transforms = [ClassifyLetterBox(size, auto=True) if rect else CenterCrop(size), ToTensor()]
if any(mean) or any(std):
transforms.append(T.Normalize(mean, std, inplace=True))
return T.Compose(transforms)
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
disable_color_jitter = False
if auto_augment:
assert isinstance(auto_augment, str)
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter
def hsv2colorjitter(h, s, v):
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
return v, v, s, h
def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
hflip=0.5,
vflip=0.0,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False,
):
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
prefix = colorstr('albumentations: ')
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
check_version(A.__version__, '1.0.3', hard=True) # version requirement
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentations
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
if auto_augment == 'randaugment':
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if any((hsv_h, hsv_s, hsv_v)):
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
return A.Compose(T)
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f'{prefix}{e}')
elif auto_augment == 'augmix':
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
elif auto_augment == 'autoaugment':
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
else:
raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None')
if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
T.RandomErasing(p=erasing, inplace=True)]
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
# NOTE: keep this class for backward compatibility
class ClassifyLetterBox:
"""
YOLOv8 LetterBox class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
@ -1091,6 +1175,7 @@ class ClassifyLetterBox:
return im_out
# NOTE: keep this class for backward compatibility
class CenterCrop:
"""YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
T.Compose([CenterCrop(size), ToTensor()]).
@ -1117,6 +1202,7 @@ class CenterCrop:
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
# NOTE: keep this class for backward compatibility
class ToTensor:
"""YOLOv8 ToTensor class for image preprocessing, i.e., T.Compose([LetterBox(size), ToTensor()])."""