ultralytics 8.3.21 NVIDIA DLA export support (#16449)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Lakshantha Dissanayake <lakshanthad@yahoo.com> Co-authored-by: Lakshantha <lakshantha@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
parent
b8fbee3a97
commit
8f0a94409f
4 changed files with 66 additions and 5 deletions
|
|
@ -194,6 +194,11 @@ class Exporter:
|
|||
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
|
||||
|
||||
# Device
|
||||
dla = None
|
||||
if fmt == "engine" and "dla" in self.args.device:
|
||||
dla = self.args.device.split(":")[-1]
|
||||
assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}."
|
||||
self.args.device = "0"
|
||||
if fmt == "engine" and self.args.device is None:
|
||||
LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0")
|
||||
self.args.device = "0"
|
||||
|
|
@ -309,7 +314,7 @@ class Exporter:
|
|||
if jit or ncnn: # TorchScript
|
||||
f[0], _ = self.export_torchscript()
|
||||
if engine: # TensorRT required before ONNX
|
||||
f[1], _ = self.export_engine()
|
||||
f[1], _ = self.export_engine(dla=dla)
|
||||
if onnx: # ONNX
|
||||
f[2], _ = self.export_onnx()
|
||||
if xml: # OpenVINO
|
||||
|
|
@ -682,7 +687,7 @@ class Exporter:
|
|||
return f, ct_model
|
||||
|
||||
@try_export
|
||||
def export_engine(self, prefix=colorstr("TensorRT:")):
|
||||
def export_engine(self, dla=None, prefix=colorstr("TensorRT:")):
|
||||
"""YOLO TensorRT export https://developer.nvidia.com/tensorrt."""
|
||||
assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. use 'device=0'"
|
||||
f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016
|
||||
|
|
@ -717,6 +722,20 @@ class Exporter:
|
|||
network = builder.create_network(flag)
|
||||
half = builder.platform_has_fast_fp16 and self.args.half
|
||||
int8 = builder.platform_has_fast_int8 and self.args.int8
|
||||
|
||||
# Optionally switch to DLA if enabled
|
||||
if dla is not None:
|
||||
if not IS_JETSON:
|
||||
raise ValueError("DLA is only available on NVIDIA Jetson devices")
|
||||
LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
|
||||
if not self.args.half and not self.args.int8:
|
||||
raise ValueError(
|
||||
"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
|
||||
)
|
||||
config.default_device_type = trt.DeviceType.DLA
|
||||
config.DLA_core = int(dla)
|
||||
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
|
||||
|
||||
# Read ONNX file
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
if not parser.parse_from_file(f_onnx):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue