Add max_dim==2 argument to check_imgsz() (#789)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-04 01:48:44 +04:00 committed by GitHub
parent 5a80ad98db
commit 0d182e80f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 52 deletions

View file

@ -202,7 +202,7 @@ class BaseTrainer:
self.model = DDP(self.model, device_ids=[rank])
# Check imgsz
gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32) # grid size (max stride)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs)
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
# Batch size
if self.batch_size == -1:
if RANK == -1: # single-GPU only, estimate best batch size