Improved float workspace arg for TRT exports (#9407)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Jiacong Fang 2024-03-29 23:22:32 +08:00 committed by GitHub
parent 4a7ccba0af
commit 03d0ffd9f5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 2 additions and 5 deletions

View file

@ -94,7 +94,7 @@ CLI_HELP_MSG = f"""
""" """
# Define keys for arg type checks # Define keys for arg type checks
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"} CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
CFG_FRACTION_KEYS = { CFG_FRACTION_KEYS = {
"dropout", "dropout",
"iou", "iou",
@ -132,7 +132,6 @@ CFG_INT_KEYS = {
"max_det", "max_det",
"vid_stride", "vid_stride",
"line_width", "line_width",
"workspace",
"nbs", "nbs",
"save_period", "save_period",
} }

View file

@ -675,9 +675,7 @@ class Exporter:
builder = trt.Builder(logger) builder = trt.Builder(logger)
config = builder.create_builder_config() config = builder.create_builder_config()
config.max_workspace_size = self.args.workspace * 1 << 30 config.max_workspace_size = int(self.args.workspace * (1 << 30))
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag) network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger) parser = trt.OnnxParser(network, logger)