Add max_dim==2 argument to check_imgsz() (#789)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-04 01:48:44 +04:00 committed by GitHub
parent 5a80ad98db
commit 0d182e80f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 52 deletions

View file

@ -40,7 +40,7 @@ def is_ascii(s) -> bool:
return all(ord(c) < 128 for c in s)
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
"""
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
@ -66,6 +66,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
# Apply max_dim
if max_dim == 1:
LOGGER.warning(f"WARNING ⚠️ 'train' and 'val' imgsz types must be integer, updating to 'imgsz={max(imgsz)}'. "
f"'predict' and 'export' imgsz may be list or integer, "
f"i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]