Fix error on TensorRT export with float workspace value (#17352)
This commit is contained in:
parent
603fa84774
commit
d0abd95f95
1 changed files with 1 additions and 1 deletions
|
|
@ -791,7 +791,7 @@ class Exporter:
|
|||
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
|
||||
profile = builder.create_optimization_profile()
|
||||
min_shape = (1, shape[1], 32, 32) # minimum input shape
|
||||
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
|
||||
max_shape = (*shape[:2], *(int(max(1, self.args.workspace) * d) for d in shape[2:])) # max input shape
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue