ultralytics 8.1.46 add TensorRT 10 support (#9516)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: 九是否随意的称呼 <1069679911@qq.com>
This commit is contained in:
parent
ea03db9984
commit
4ffd6ee6d7
4 changed files with 77 additions and 32 deletions
|
|
@@ -234,23 +234,47 @@ class AutoBackend(nn.Module):
|
|||
meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length
|
||||
metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata
|
||||
model = runtime.deserialize_cuda_engine(f.read()) # read engine
|
||||
context = model.create_execution_context()
|
||||
|
||||
# Model context
|
||||
try:
|
||||
context = model.create_execution_context()
|
||||
except Exception as e: # model is None
|
||||
LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n")
|
||||
raise e
|
||||
|
||||
bindings = OrderedDict()
|
||||
output_names = []
|
||||
fp16 = False # default updated below
|
||||
dynamic = False
|
||||
for i in range(model.num_bindings):
|
||||
name = model.get_binding_name(i)
|
||||
dtype = trt.nptype(model.get_binding_dtype(i))
|
||||
if model.binding_is_input(i):
|
||||
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
||||
dynamic = True
|
||||
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else: # output
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_binding_shape(i))
|
||||
is_trt10 = not hasattr(model, "num_bindings")
|
||||
num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings)
|
||||
for i in num:
|
||||
if is_trt10:
|
||||
name = model.get_tensor_name(i)
|
||||
dtype = trt.nptype(model.get_tensor_dtype(name))
|
||||
is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT
|
||||
if is_input:
|
||||
if -1 in tuple(model.get_tensor_shape(name)):
|
||||
dynamic = True
|
||||
context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else:
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_tensor_shape(name))
|
||||
else: # TensorRT < 10.0
|
||||
name = model.get_binding_name(i)
|
||||
dtype = trt.nptype(model.get_binding_dtype(i))
|
||||
is_input = model.binding_is_input(i)
|
||||
if model.binding_is_input(i):
|
||||
if -1 in tuple(model.get_binding_shape(i)): # dynamic
|
||||
dynamic = True
|
||||
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1]))
|
||||
if dtype == np.float16:
|
||||
fp16 = True
|
||||
else:
|
||||
output_names.append(name)
|
||||
shape = tuple(context.get_binding_shape(i))
|
||||
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
|
||||
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
|
||||
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
|
||||
|
|
@@ -463,13 +487,20 @@ class AutoBackend(nn.Module):
|
|||
|
||||
# TensorRT
|
||||
elif self.engine:
|
||||
if self.dynamic and im.shape != self.bindings["images"].shape:
|
||||
i = self.model.get_binding_index("images")
|
||||
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
|
||||
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
|
||||
for name in self.output_names:
|
||||
i = self.model.get_binding_index(name)
|
||||
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
||||
if self.dynamic or im.shape != self.bindings["images"].shape:
|
||||
if self.is_trt10:
|
||||
self.context.set_input_shape("images", im.shape)
|
||||
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
|
||||
for name in self.output_names:
|
||||
self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name)))
|
||||
else:
|
||||
i = self.model.get_binding_index("images")
|
||||
self.context.set_binding_shape(i, im.shape)
|
||||
self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape)
|
||||
for name in self.output_names:
|
||||
i = self.model.get_binding_index(name)
|
||||
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
|
||||
|
||||
s = self.bindings["images"].shape
|
||||
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
|
||||
self.binding_addrs["images"] = int(im.data_ptr())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue