Resolve albumentations UserWarning (#13098)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d7bbfa42ef
commit
a357fac441
2 changed files with 68 additions and 11 deletions
|
|
@ -28,4 +28,4 @@ priytosh.revolution@live.com: priytosh-tripathi
|
||||||
shuizhuyuanluo@126.com: null
|
shuizhuyuanluo@126.com: null
|
||||||
stormsson@users.noreply.github.com: stormsson
|
stormsson@users.noreply.github.com: stormsson
|
||||||
xinwang614@gmail.com: GreatV
|
xinwang614@gmail.com: GreatV
|
||||||
andrei.kochin@intel: andrei-kochin
|
andrei.kochin@intel.com: andrei-kochin
|
||||||
|
|
|
||||||
|
|
@ -874,11 +874,56 @@ class Albumentations:
|
||||||
self.p = p
|
self.p = p
|
||||||
self.transform = None
|
self.transform = None
|
||||||
prefix = colorstr("albumentations: ")
|
prefix = colorstr("albumentations: ")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
|
|
||||||
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
check_version(A.__version__, "1.0.3", hard=True) # version requirement
|
||||||
|
|
||||||
|
# List of possible spatial transforms
|
||||||
|
spatial_transforms = {
|
||||||
|
"Affine",
|
||||||
|
"BBoxSafeRandomCrop",
|
||||||
|
"CenterCrop",
|
||||||
|
"CoarseDropout",
|
||||||
|
"Crop",
|
||||||
|
"CropAndPad",
|
||||||
|
"CropNonEmptyMaskIfExists",
|
||||||
|
"D4",
|
||||||
|
"ElasticTransform",
|
||||||
|
"Flip",
|
||||||
|
"GridDistortion",
|
||||||
|
"GridDropout",
|
||||||
|
"HorizontalFlip",
|
||||||
|
"Lambda",
|
||||||
|
"LongestMaxSize",
|
||||||
|
"MaskDropout",
|
||||||
|
"MixUp",
|
||||||
|
"Morphological",
|
||||||
|
"NoOp",
|
||||||
|
"OpticalDistortion",
|
||||||
|
"PadIfNeeded",
|
||||||
|
"Perspective",
|
||||||
|
"PiecewiseAffine",
|
||||||
|
"PixelDropout",
|
||||||
|
"RandomCrop",
|
||||||
|
"RandomCropFromBorders",
|
||||||
|
"RandomGridShuffle",
|
||||||
|
"RandomResizedCrop",
|
||||||
|
"RandomRotate90",
|
||||||
|
"RandomScale",
|
||||||
|
"RandomSizedBBoxSafeCrop",
|
||||||
|
"RandomSizedCrop",
|
||||||
|
"Resize",
|
||||||
|
"Rotate",
|
||||||
|
"SafeRotate",
|
||||||
|
"ShiftScaleRotate",
|
||||||
|
"SmallestMaxSize",
|
||||||
|
"Transpose",
|
||||||
|
"VerticalFlip",
|
||||||
|
"XYMasking",
|
||||||
|
} # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
T = [
|
T = [
|
||||||
A.Blur(p=0.01),
|
A.Blur(p=0.01),
|
||||||
|
|
@ -889,8 +934,14 @@ class Albumentations:
|
||||||
A.RandomGamma(p=0.0),
|
A.RandomGamma(p=0.0),
|
||||||
A.ImageCompression(quality_lower=75, p=0.0),
|
A.ImageCompression(quality_lower=75, p=0.0),
|
||||||
]
|
]
|
||||||
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
|
||||||
|
|
||||||
|
# Compose transforms
|
||||||
|
self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T)
|
||||||
|
self.transform = (
|
||||||
|
A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
|
||||||
|
if self.contains_spatial
|
||||||
|
else A.Compose(T)
|
||||||
|
)
|
||||||
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
|
||||||
except ImportError: # package not installed, skip
|
except ImportError: # package not installed, skip
|
||||||
pass
|
pass
|
||||||
|
|
@ -899,20 +950,26 @@ class Albumentations:
|
||||||
|
|
||||||
def __call__(self, labels):
|
def __call__(self, labels):
|
||||||
"""Generates object detections and returns a dictionary with detection results."""
|
"""Generates object detections and returns a dictionary with detection results."""
|
||||||
im = labels["img"]
|
if self.transform is None or random.random() > self.p:
|
||||||
cls = labels["cls"]
|
return labels
|
||||||
if len(cls):
|
|
||||||
labels["instances"].convert_bbox("xywh")
|
if self.contains_spatial:
|
||||||
labels["instances"].normalize(*im.shape[:2][::-1])
|
cls = labels["cls"]
|
||||||
bboxes = labels["instances"].bboxes
|
if len(cls):
|
||||||
# TODO: add supports of segments and keypoints
|
im = labels["img"]
|
||||||
if self.transform and random.random() < self.p:
|
labels["instances"].convert_bbox("xywh")
|
||||||
|
labels["instances"].normalize(*im.shape[:2][::-1])
|
||||||
|
bboxes = labels["instances"].bboxes
|
||||||
|
# TODO: add supports of segments and keypoints
|
||||||
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
|
||||||
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
|
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
|
||||||
labels["img"] = new["image"]
|
labels["img"] = new["image"]
|
||||||
labels["cls"] = np.array(new["class_labels"])
|
labels["cls"] = np.array(new["class_labels"])
|
||||||
bboxes = np.array(new["bboxes"], dtype=np.float32)
|
bboxes = np.array(new["bboxes"], dtype=np.float32)
|
||||||
labels["instances"].update(bboxes=bboxes)
|
labels["instances"].update(bboxes=bboxes)
|
||||||
|
else:
|
||||||
|
labels["img"] = self.transform(image=labels["img"])["image"] # transformed
|
||||||
|
|
||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue