ultralytics 8.2.11 new TensorRT INT8 export feature (#10165)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1d9745182d
commit
fcfc44ea9c
15 changed files with 601 additions and 176 deletions
|
|
@ -200,6 +200,8 @@ class Exporter:
|
|||
self.args.half = False
|
||||
assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one."
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size
|
||||
if self.args.int8 and engine:
|
||||
self.args.dynamic = True # enforce dynamic to export TensorRT INT8; ensures ONNX is dynamic
|
||||
if self.args.optimize:
|
||||
assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False"
|
||||
assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'"
|
||||
|
|
@ -349,12 +351,12 @@ class Exporter:
|
|||
task=self.model.task,
|
||||
imgsz=self.imgsz[0],
|
||||
augment=False,
|
||||
batch_size=self.args.batch,
|
||||
batch_size=self.args.batch * 2, # NOTE TensorRT INT8 calibration should use 2x batch size
|
||||
)
|
||||
n = len(dataset)
|
||||
if n < 300:
|
||||
LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.")
|
||||
return build_dataloader(dataset, batch=self.args.batch, workers=0) # required for batch loading
|
||||
return build_dataloader(dataset, batch=self.args.batch * 2, workers=0) # required for batch loading
|
||||
|
||||
@try_export
|
||||
def export_torchscript(self, prefix=colorstr("TorchScript:")):
|
||||
|
|
@ -679,6 +681,7 @@ class Exporter:
|
|||
import tensorrt as trt # noqa
|
||||
check_version(trt.__version__, "7.0.0", hard=True) # require tensorrt>=7.0.0
|
||||
|
||||
# Setup and checks
|
||||
LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...")
|
||||
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
|
||||
assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
|
||||
|
|
@ -687,6 +690,7 @@ class Exporter:
|
|||
if self.args.verbose:
|
||||
logger.min_severity = trt.Logger.Severity.VERBOSE
|
||||
|
||||
# Engine builder
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
workspace = int(self.args.workspace * (1 << 30))
|
||||
|
|
@ -696,10 +700,14 @@ class Exporter:
|
|||
config.max_workspace_size = workspace
|
||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(flag)
|
||||
half = builder.platform_has_fast_fp16 and self.args.half
|
||||
int8 = builder.platform_has_fast_int8 and self.args.int8
|
||||
# Read ONNX file
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
if not parser.parse_from_file(f_onnx):
|
||||
raise RuntimeError(f"failed to load ONNX file: {f_onnx}")
|
||||
|
||||
# Network inputs
|
||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
||||
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
||||
for inp in inputs:
|
||||
|
|
@ -713,15 +721,67 @@ class Exporter:
|
|||
LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
|
||||
profile = builder.create_optimization_profile()
|
||||
min_shape = (1, shape[1], 32, 32) # minimum input shape
|
||||
opt_shape = (max(1, shape[0] // 2), *shape[1:]) # optimal input shape
|
||||
max_shape = (*shape[:2], *(max(1, self.args.workspace) * d for d in shape[2:])) # max input shape
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
|
||||
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
|
||||
half = builder.platform_has_fast_fp16 and self.args.half
|
||||
LOGGER.info(f"{prefix} building FP{16 if half else 32} engine as {f}")
|
||||
if half:
|
||||
LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {f}")
|
||||
if int8:
|
||||
config.set_flag(trt.BuilderFlag.INT8)
|
||||
config.set_calibration_profile(profile)
|
||||
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
|
||||
|
||||
class EngineCalibrator(trt.IInt8Calibrator):
|
||||
def __init__(
|
||||
self,
|
||||
dataset, # ultralytics.data.build.InfiniteDataLoader
|
||||
batch: int,
|
||||
cache: str = "",
|
||||
) -> None:
|
||||
trt.IInt8Calibrator.__init__(self)
|
||||
self.dataset = dataset
|
||||
self.data_iter = iter(dataset)
|
||||
self.algo = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
|
||||
self.batch = batch
|
||||
self.cache = Path(cache)
|
||||
|
||||
def get_algorithm(self) -> trt.CalibrationAlgoType:
|
||||
"""Get the calibration algorithm to use."""
|
||||
return self.algo
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""Get the batch size to use for calibration."""
|
||||
return self.batch or 1
|
||||
|
||||
def get_batch(self, names) -> list:
|
||||
"""Get the next batch to use for calibration, as a list of device memory pointers."""
|
||||
try:
|
||||
im0s = next(self.data_iter)["img"] / 255.0
|
||||
im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
|
||||
return [int(im0s.data_ptr())]
|
||||
except StopIteration:
|
||||
# Return [] or None, signal to TensorRT there is no calibration data remaining
|
||||
return None
|
||||
|
||||
def read_calibration_cache(self) -> bytes:
|
||||
"""Use existing cache instead of calibrating again, otherwise, implicitly return None."""
|
||||
if self.cache.exists() and self.cache.suffix == ".cache":
|
||||
return self.cache.read_bytes()
|
||||
|
||||
def write_calibration_cache(self, cache) -> None:
|
||||
"""Write calibration cache to disk."""
|
||||
_ = self.cache.write_bytes(cache)
|
||||
|
||||
# Load dataset w/ builder (for batching) and calibrate
|
||||
dataset = self.get_int8_calibration_dataloader(prefix)
|
||||
config.int8_calibrator = EngineCalibrator(
|
||||
dataset=dataset,
|
||||
batch=2 * self.args.batch,
|
||||
cache=self.file.with_suffix(".cache"),
|
||||
)
|
||||
|
||||
elif half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Free CUDA memory
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue