ultralytics 8.0.202 sort Triton model outputs (#5945)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mike Tune <mtuneoff@gmail.com>
This commit is contained in:
parent
a05edfbc27
commit
e58db228c2
10 changed files with 138 additions and 133 deletions
|
|
@ -53,10 +53,13 @@ class TritonRemoteModel:
|
|||
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
|
||||
config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
|
||||
|
||||
# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
|
||||
config['output'] = sorted(config['output'], key=lambda x: x.get('name'))
|
||||
|
||||
# Define model attributes
|
||||
type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
|
||||
self.InferRequestedOutput = client.InferRequestedOutput
|
||||
self.InferInput = client.InferInput
|
||||
|
||||
type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
|
||||
self.input_formats = [x['data_type'] for x in config['input']]
|
||||
self.np_input_formats = [type_map[x] for x in self.input_formats]
|
||||
self.input_names = [x['name'] for x in config['input']]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue