Fix error on TensorRT export with float workspace value (#17352)

This commit is contained in:
Mohammed Yasin 2024-11-05 17:48:33 +08:00 committed by GitHub
parent 603fa84774
commit d0abd95f95
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -791,7 +791,7 @@ class Exporter:
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
max_shape = (*shape[:2], *(int(max(1, self.args.workspace) * d) for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
config.add_optimization_profile(profile)