Add max_dim==2 argument to check_imgsz() (#789)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
parent
5a80ad98db
commit
0d182e80f1
11 changed files with 96 additions and 52 deletions
|
|
@ -40,7 +40,7 @@ def is_ascii(s) -> bool:
|
|||
return all(ord(c) < 128 for c in s)
|
||||
|
||||
|
||||
def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
||||
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
|
||||
"""
|
||||
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
|
||||
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
|
||||
|
|
@ -66,6 +66,13 @@ def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
|
|||
raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
|
||||
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")
|
||||
|
||||
# Apply max_dim
|
||||
if max_dim == 1:
|
||||
LOGGER.warning(f"WARNING ⚠️ 'train' and 'val' imgsz types must be integer, updating to 'imgsz={max(imgsz)}'. "
|
||||
f"'predict' and 'export' imgsz may be list or integer, "
|
||||
f"i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'")
|
||||
imgsz = [max(imgsz)]
|
||||
|
||||
# Make image size a multiple of the stride
|
||||
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue