Exported model batch size validation fix (#14845)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
8648572809
commit
db4c43bafb
1 changed files with 2 additions and 2 deletions
|
|
@ -136,8 +136,8 @@ class BaseValidator:
|
|||
if engine:
|
||||
self.args.batch = model.batch_size
|
||||
elif not pt and not jit:
|
||||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
||||
self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1
|
||||
LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
|
||||
|
||||
if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue