diff --git a/tests/test_cuda.py b/tests/test_cuda.py
index ed08f700..61954a2f 100644
--- a/tests/test_cuda.py
+++ b/tests/test_cuda.py
@@ -19,6 +19,13 @@ def test_checks():
     assert torch.cuda.is_available() == CUDA_IS_AVAILABLE
     assert torch.cuda.device_count() == CUDA_DEVICE_COUNT
 
 
+@pytest.mark.slow
+@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
+def test_export_engine():
+    """Test exporting the YOLO model to NVIDIA TensorRT format."""
+    f = YOLO(MODEL).export(format="engine", device=0)
+    YOLO(f)(BUS, device=0)
+
 @pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available")
 def test_train():
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index c0a192bf..4125944c 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.45"
+__version__ = "8.1.46"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index c5385e80..54edbf57 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -658,6 +658,7 @@ class Exporter:
     def export_engine(self, prefix=colorstr("TensorRT:")):
         """YOLOv8 TensorRT export https://developer.nvidia.com/tensorrt."""
         assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
+        self.args.simplify = True
         f_onnx, _ = self.export_onnx()  # run before trt import https://github.com/ultralytics/ultralytics/issues/7016
 
         try:
@@ -666,12 +667,10 @@ class Exporter:
             if LINUX:
                 check_requirements("nvidia-tensorrt", cmds="-U --index-url https://pypi.ngc.nvidia.com")
             import tensorrt as trt  # noqa
-        check_version(trt.__version__, "7.0.0", hard=True)  # require tensorrt>=7.0.0
 
-        self.args.simplify = True
 
-        LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
+        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
 
         assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
         f = self.file.with_suffix(".engine")  # TensorRT engine file
         logger = trt.Logger(trt.Logger.INFO)
@@ -680,7 +679,11 @@ class Exporter:
 
         builder = trt.Builder(logger)
         config = builder.create_builder_config()
-        config.max_workspace_size = int(self.args.workspace * (1 << 30))
+        workspace = int(self.args.workspace * (1 << 30))
+        if is_trt10:
+            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
+        else:  # TensorRT versions 7, 8
+            config.max_workspace_size = workspace
         flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         network = builder.create_network(flag)
         parser = trt.OnnxParser(network, logger)
@@ -699,27 +702,31 @@ class Exporter:
             if shape[0] <= 1:
                 LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
             profile = builder.create_optimization_profile()
+            min_shape = (1, shape[1], 32, 32)  # minimum input shape
+            opt_shape = (max(1, shape[0] // 2), *shape[1:])  # optimal input shape
+            max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:]))  # max input shape
             for inp in inputs:
-                profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
+                profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
             config.add_optimization_profile(profile)
 
-        LOGGER.info(
-            f"{prefix} building FP{16 if builder.platform_has_fast_fp16 and self.args.half else 32} engine as {f}"
-        )
-        if builder.platform_has_fast_fp16 and self.args.half:
+        half = builder.platform_has_fast_fp16 and self.args.half
+        LOGGER.info(f"{prefix} building FP{16 if half else 32} engine as {f}")
+        if half:
             config.set_flag(trt.BuilderFlag.FP16)
 
+        # Free CUDA memory
         del self.model
         torch.cuda.empty_cache()
 
         # Write file
-        with builder.build_engine(network, config) as engine, open(f, "wb") as t:
+        build = builder.build_serialized_network if is_trt10 else builder.build_engine
+        with build(network, config) as engine, open(f, "wb") as t:
             # Metadata
             meta = json.dumps(self.metadata)
             t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
             t.write(meta.encode())
             # Model
-            t.write(engine.serialize())
+            t.write(engine if is_trt10 else engine.serialize())
 
         return f, None
 
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 0b473bf3..29709cdb 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -234,23 +234,47 @@ class AutoBackend(nn.Module):
                 meta_len = int.from_bytes(f.read(4), byteorder="little")  # read metadata length
                 metadata = json.loads(f.read(meta_len).decode("utf-8"))  # read metadata
                 model = runtime.deserialize_cuda_engine(f.read())  # read engine
-            context = model.create_execution_context()
+
+            # Model context
+            try:
+                context = model.create_execution_context()
+            except Exception as e:  # model is None
+                LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
+                raise e
+
             bindings = OrderedDict()
             output_names = []
             fp16 = False  # default updated below
             dynamic = False
-            for i in range(model.num_bindings):
-                name = model.get_binding_name(i)
-                dtype = trt.nptype(model.get_binding_dtype(i))
-                if model.binding_is_input(i):
-                    if -1 in tuple(model.get_binding_shape(i)):  # dynamic
-                        dynamic = True
-                        context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
-                    if dtype == np.float16:
-                        fp16 = True
-                else:  # output
-                    output_names.append(name)
-                shape = tuple(context.get_binding_shape(i))
+            is_trt10 = not hasattr(model, "num_bindings")
+            num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
+            for i in num:
+                if is_trt10:
+                    name = model.get_tensor_name(i)
+                    dtype = trt.nptype(model.get_tensor_dtype(name))
+                    is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+                    if is_input:
+                        if -1 in tuple(model.get_tensor_shape(name)):
+                            dynamic = True
+                            context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
+                        if dtype == np.float16:
+                            fp16 = True
+                    else:
+                        output_names.append(name)
+                    shape = tuple(context.get_tensor_shape(name))
+                else:  # TensorRT < 10.0
+                    name = model.get_binding_name(i)
+                    dtype = trt.nptype(model.get_binding_dtype(i))
+                    is_input = model.binding_is_input(i)
+                    if model.binding_is_input(i):
+                        if -1 in tuple(model.get_binding_shape(i)):  # dynamic
+                            dynamic = True
+                            context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
+                        if dtype == np.float16:
+                            fp16 = True
+                    else:
+                        output_names.append(name)
+                    shape = tuple(context.get_binding_shape(i))
                 im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
@@ -463,13 +487,20 @@ class AutoBackend(nn.Module):
 
         # TensorRT
         elif self.engine:
-            if self.dynamic and im.shape != self.bindings["images"].shape:
-                i = self.model.get_binding_index("images")
-                self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
-                self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
-                for name in self.output_names:
-                    i = self.model.get_binding_index(name)
-                    self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
+            if self.dynamic or im.shape != self.bindings["images"].shape:
+                if self.is_trt10:
+                    self.context.set_input_shape("images", im.shape)
+                    self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
+                    for name in self.output_names:
+                        self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
+                else:
+                    i = self.model.get_binding_index("images")
+                    self.context.set_binding_shape(i, im.shape)
+                    self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
+                    for name in self.output_names:
+                        i = self.model.get_binding_index(name)
+                        self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
+
             s = self.bindings["images"].shape
             assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
             self.binding_addrs["images"] = int(im.data_ptr())
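
Usage sketch (illustrative, not part of the diff): with these changes applied, a TensorRT engine can be exported and run the same way the new test_export_engine test does, on either TensorRT 8 or TensorRT 10. The checkpoint name and image URL below are placeholders.

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")  # placeholder checkpoint
    engine_path = model.export(format="engine", device=0)  # builds the .engine via Exporter.export_engine()
    YOLO(engine_path)("https://ultralytics.com/images/bus.jpg", device=0)  # inference loads the engine through AutoBackend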