ultralytics 8.0.233 improve Classify train augmentations (#4546)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
parent
6218b82072
commit
73dbb41920
13 changed files with 253 additions and 108 deletions
|
|
@ -505,6 +505,48 @@ def test_hub():
|
|||
smart_request('GET', 'http://github.com', progress=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image():
|
||||
return cv2.imread(str(SOURCE))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'auto_augment, erasing, force_color_jitter',
|
||||
[
|
||||
(None, 0.0, False),
|
||||
('randaugment', 0.5, True),
|
||||
('augmix', 0.2, False),
|
||||
('autoaugment', 0.0, True), ],
|
||||
)
|
||||
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
|
||||
import torchvision.transforms as T
|
||||
|
||||
from ultralytics.data.augment import classify_augmentations
|
||||
|
||||
transform = classify_augmentations(
|
||||
size=224,
|
||||
mean=(0.5, 0.5, 0.5),
|
||||
std=(0.5, 0.5, 0.5),
|
||||
scale=(0.08, 1.0),
|
||||
ratio=(3.0 / 4.0, 4.0 / 3.0),
|
||||
hflip=0.5,
|
||||
vflip=0.5,
|
||||
auto_augment=auto_augment,
|
||||
hsv_h=0.015,
|
||||
hsv_s=0.4,
|
||||
hsv_v=0.4,
|
||||
force_color_jitter=force_color_jitter,
|
||||
erasing=erasing,
|
||||
interpolation=T.InterpolationMode.BILINEAR,
|
||||
)
|
||||
|
||||
transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
|
||||
|
||||
assert transformed_image.shape == (3, 224, 224)
|
||||
assert torch.is_tensor(transformed_image)
|
||||
assert transformed_image.dtype == torch.float32
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
|
||||
def test_model_tune():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue