Improved float workspace arg for TRT exports (#9407)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
4a7ccba0af
commit
03d0ffd9f5
2 changed files with 2 additions and 5 deletions
|
|
@ -94,7 +94,7 @@ CLI_HELP_MSG = f"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Define keys for arg type checks
|
# Define keys for arg type checks
|
||||||
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
|
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
|
||||||
CFG_FRACTION_KEYS = {
|
CFG_FRACTION_KEYS = {
|
||||||
"dropout",
|
"dropout",
|
||||||
"iou",
|
"iou",
|
||||||
|
|
@ -132,7 +132,6 @@ CFG_INT_KEYS = {
|
||||||
"max_det",
|
"max_det",
|
||||||
"vid_stride",
|
"vid_stride",
|
||||||
"line_width",
|
"line_width",
|
||||||
"workspace",
|
|
||||||
"nbs",
|
"nbs",
|
||||||
"save_period",
|
"save_period",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -675,9 +675,7 @@ class Exporter:
|
||||||
|
|
||||||
builder = trt.Builder(logger)
|
builder = trt.Builder(logger)
|
||||||
config = builder.create_builder_config()
|
config = builder.create_builder_config()
|
||||||
config.max_workspace_size = self.args.workspace * 1 << 30
|
config.max_workspace_size = int(self.args.workspace * (1 << 30))
|
||||||
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
|
|
||||||
|
|
||||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||||
network = builder.create_network(flag)
|
network = builder.create_network(flag)
|
||||||
parser = trt.OnnxParser(network, logger)
|
parser = trt.OnnxParser(network, logger)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue