ultralytics 8.1.43 40% faster ultralytics imports (#9547)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-04-05 15:29:09 +02:00 committed by GitHub
parent 99c61d6f7b
commit a2628657a1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 240 additions and 225 deletions

View file

@ -8,7 +8,7 @@ from typing import Tuple, Union
import cv2
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from ultralytics.utils import LOGGER, colorstr
from ultralytics.utils.checks import check_version
@ -167,8 +167,8 @@ class BaseMixTransform:
text2id = {text: i for i, text in enumerate(mix_texts)}
for label in [labels] + labels["mix_labels"]:
for i, l in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(l)]
for i, cls in enumerate(label["cls"].squeeze(-1).tolist()):
text = label["texts"][int(cls)]
label["cls"][i] = text2id[tuple(text)]
label["texts"] = mix_texts
return labels
@ -1133,7 +1133,7 @@ def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation=Image.BILINEAR,
crop_fraction: float = DEFAULT_CROP_FTACTION,
):
"""
@ -1149,6 +1149,7 @@ def classify_transforms(
Returns:
(T.Compose): torchvision transforms
"""
import torchvision.transforms as T # scope for faster 'import ultralytics'
if isinstance(size, (tuple, list)):
assert len(size) == 2
@ -1157,12 +1158,12 @@ def classify_transforms(
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
# Aspect ratio is preserved, crops center within image, no borders are added, part of the image is lost
if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
else:
# resize shortest edge to matching target dim for non-square target
# Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
@ -1192,7 +1193,7 @@ def classify_augmentations(
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.0,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
interpolation=Image.BILINEAR,
):
"""
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
@ -1216,7 +1217,9 @@ def classify_augmentations(
Returns:
(T.Compose): torchvision transforms
"""
# Transforms to apply if albumentations not installed
# Transforms to apply if Albumentations is not installed
import torchvision.transforms as T # scope for faster 'import ultralytics'
if not isinstance(size, int):
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range