ultralytics 8.1.46 add TensorRT 10 support (#9516)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: 九是否随意的称呼 <1069679911@qq.com>
This commit is contained in:
parent
ea03db9984
commit
4ffd6ee6d7
4 changed files with 77 additions and 32 deletions
|
|
@@ -658,6 +658,7 @@ class Exporter:
|
|||
def export_engine(self, prefix=colorstr("TensorRT:")):
|
||||
"""YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
|
||||
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||
self.args.simplify = True
|
||||
f_onnx, _ = self.export_onnx() # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
|
||||
|
||||
try:
|
||||
|
|
@@ -666,12 +667,10 @@ class Exporter:
|
|||
if LINUX:
|
||||
check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
|
||||
import tensorrt as trt # noqa
|
||||
|
||||
check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
|
||||
|
||||
self.args.simplify = True
|
||||
|
||||
LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
|
||||
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
|
||||
assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
|
||||
f = self.file.with_suffix(".engine") # TensorRT engine file
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
|
|
@@ -680,7 +679,11 @@ class Exporter:
|
|||
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
config.max_workspace_size = int(self.args.workspace * (1 << 30))
|
||||
workspace = int(self.args.workspace * (1 << 30))
|
||||
if is_trt10:
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
|
||||
else: # TensorRT versions 7, 8
|
||||
config.max_workspace_size = workspace
|
||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(flag)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
|
|
@@ -699,27 +702,31 @@ class Exporter:
|
|||
if shape[0] <= 1:
|
||||
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
|
||||
profile = builder.create_optimization_profile()
|
||||
min_shape = (1, shape[1], 32, 32) # minimum input shape
|
||||
opt_shape = (max(1, shape[0] // 2), *shape[1:]) # optimal input shape
|
||||
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
|
||||
profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
|
||||
LOGGER.info(
|
||||
f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
|
||||
)
|
||||
if builder.platform_has_fast_fp16 and self.args.half:
|
||||
half = builder.platform_has_fast_fp16 and self.args.half
|
||||
LOGGER.info(f"{prefix} building FP{16 if half else 32} engine as {f}")
|
||||
if half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Free CUDA memory
|
||||
del self.model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Write file
|
||||
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
|
||||
build = builder.build_serialized_network if is_trt10 else builder.build_engine
|
||||
with build(network, config) as engine, open(f, "wb") as t:
|
||||
# Metadata
|
||||
meta = json.dumps(self.metadata)
|
||||
t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
|
||||
t.write(meta.encode())
|
||||
# Model
|
||||
t.write(engine.serialize())
|
||||
t.write(engine if is_trt10 else engine.serialize())
|
||||
|
||||
return f, None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue