ultralytics 8.0.195 NVIDIA Triton Inference Server support (#5257)
Co-authored-by: TheConstant3 <46416203+TheConstant3@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 40e3923cfc
commit c7aa83da31
21 changed files with 349 additions and 98 deletions
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.194'
+__version__ = '8.0.195'

from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@@ -81,6 +81,12 @@ class Model(nn.Module):
            self.session = HUBTrainingSession(model)
            model = self.session.model_file

+        # Check if Triton Server model
+        elif self.is_triton_model(model):
+            self.model = model
+            self.task = task
+            return
+
        # Load or create new YOLO model
        suffix = Path(model).suffix
        if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
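For context (an illustration, not part of the diff): with the new elif branch above, a Triton endpoint URL can be passed straight to the YOLO constructor. A minimal sketch, assuming a Triton server is already serving a model named 'yolov8n' over HTTP on localhost:8000; the address, endpoint name, and image path are placeholders:

    from ultralytics import YOLO

    model = YOLO('http://localhost:8000/yolov8n', task='detect')  # task set explicitly, mirroring the self.task = task assignment above
    results = model('image.jpg')  # inference is routed to the Triton server instead of a local weights file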
@@ -94,6 +100,13 @@ class Model(nn.Module):
        """Calls the 'predict' function with given arguments to perform object detection."""
        return self.predict(source, stream, **kwargs)

+    @staticmethod
+    def is_triton_model(model):
+        """Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
+        from urllib.parse import urlsplit
+        url = urlsplit(model)
+        return url.netloc and url.path and url.scheme in {'http', 'grpc'}
+
    @staticmethod
    def is_hub_model(model):
        """Check if the provided model is a HUB model."""
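An illustration (not in the commit) of what the urlsplit-based check accepts; the URLs are made-up examples:

    from urllib.parse import urlsplit

    def is_triton_url(model: str) -> bool:
        """Same check as is_triton_model() above: an http/grpc scheme plus a non-empty netloc and path."""
        url = urlsplit(model)
        return bool(url.netloc and url.path and url.scheme in {'http', 'grpc'})

    print(is_triton_url('http://localhost:8000/yolov8n'))     # True
    print(is_triton_url('grpc://192.168.1.10:8001/yolov8n'))  # True
    print(is_triton_url('yolov8n.pt'))                        # False: no scheme or netloc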
@@ -15,13 +15,14 @@ class FastSAMPredictor(DetectionPredictor):
        self.args.task = 'segment'

    def postprocess(self, preds, img, orig_imgs):
-        p = ops.non_max_suppression(preds[0],
-                                    self.args.conf,
-                                    self.args.iou,
-                                    agnostic=self.args.agnostic_nms,
-                                    max_det=self.args.max_det,
-                                    nc=len(self.model.names),
-                                    classes=self.args.classes)
+        p = ops.non_max_suppression(
+            preds[0],
+            self.args.conf,
+            self.args.iou,
+            agnostic=self.args.agnostic_nms,
+            max_det=self.args.max_det,
+            nc=1,  # set to 1 class since SAM has no class predictions
+            classes=self.args.classes)
        full_box = torch.zeros(p[0].shape[1], device=p[0].device)
        full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
        full_box = full_box.view(1, -1)

@@ -7,7 +7,6 @@ import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
-from urllib.parse import urlparse

import cv2
import numpy as np
@@ -32,8 +31,8 @@ def check_class_names(names):
        raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
                       f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
    if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
-        map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map']  # human-readable names
-        names = {k: map[v] for k, v in names.items()}
+        names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map']  # human-readable names
+        names = {k: names_map[v] for k, v in names.items()}
    return names
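A side note on the rename above (an illustration, not part of the diff): binding the name map inside the function shadows Python's built-in map(), which names_map avoids:

    names_map = {'n01440764': 'tench'}   # new name: the builtin map() stays usable
    print(list(map(str, [1, 2])))        # ['1', '2']
    map = {'n01440764': 'tench'}         # old name: rebinding shadows the builtin from here on
    # list(map(str, [1, 2]))             # would now raise TypeError: 'dict' object is not callable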
@@ -274,13 +273,9 @@ class AutoBackend(nn.Module):
            net.load_model(str(w.with_suffix('.bin')))
            metadata = w.parent / 'metadata.yaml'
        elif triton:  # NVIDIA Triton Inference Server
-            """TODO
-            check_requirements('tritonclient[all]')
-            from utils.triton import TritonRemoteModel
-            model = TritonRemoteModel(url=w)
-            nhwc = model.runtime.startswith("tensorflow")
-            """
-            raise NotImplementedError('Triton Inference Server is not currently supported.')
+            from ultralytics.utils.triton import TritonRemoteModel
+            model = TritonRemoteModel(w)
        else:
            from ultralytics.engine.exporter import export_formats
            raise TypeError(f"model='{w}' is not a supported model format. "
@@ -395,6 +390,7 @@ class AutoBackend(nn.Module):
                ex.extract(output_name, mat_out)
                y.append(np.array(mat_out)[None])
        elif self.triton:  # NVIDIA Triton Inference Server
            im = im.cpu().numpy()  # torch to numpy
            y = self.model(im)
        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
            im = im.cpu().numpy()
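To make the data handoff in the branch above concrete (a sketch, not from the commit; fake_triton stands in for an already-constructed TritonRemoteModel and simply echoes its input):

    import numpy as np
    import torch

    def fake_triton(im: np.ndarray):      # stand-in for TritonRemoteModel.__call__
        return [im]                       # the real client also returns a list of numpy arrays

    im = torch.zeros(1, 3, 640, 640)      # AutoBackend.forward() receives a torch tensor
    y = fake_triton(im.cpu().numpy())     # the Triton branch converts to numpy before the remote call
    y = [torch.from_numpy(x) for x in y]  # forward() later maps numpy outputs back to torch tensors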
@@ -498,6 +494,8 @@ class AutoBackend(nn.Module):
        if any(types):
            triton = False
        else:
-            url = urlparse(p)  # if url may be Triton inference server
-            triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
+            from urllib.parse import urlsplit
+            url = urlsplit(p)
+            triton = url.netloc and url.path and url.scheme in {'http', 'grpc'}

        return types + [triton]

@@ -705,7 +705,7 @@ def remove_colorstr(input_string):
        >>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
        >>> 'hello world'
    """
-    ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
+    ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
    return ansi_escape.sub('', input_string)

@@ -2,6 +2,7 @@
"""
Model validation metrics
"""

import math
import warnings
from pathlib import Path
ultralytics/utils/triton.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

from typing import List
from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """Client for interacting with a remote Triton Inference Server model.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
    """

    def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
        """
        Initialize the TritonRemoteModel.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str): The name of the model on the Triton server.
            scheme (str): The communication scheme ('http' or 'grpc').
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip('/').split('/')[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == 'http':
            import tritonclient.http as client  # noqa
            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa
            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)['config']

        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput

        type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
        self.input_formats = [x['data_type'] for x in config['input']]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x['name'] for x in config['input']]
        self.output_names = [x['name'] for x in config['output']]

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs.

        Args:
            *inputs (List[np.ndarray]): Input data to the model.

        Returns:
            List[np.ndarray]: Model outputs.
        """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
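A usage sketch for the new class (not part of the commit); the server address, endpoint name 'yolov8n', and input shape are placeholders and assume the served model takes a single FP32 image tensor:

    import numpy as np
    from ultralytics.utils.triton import TritonRemoteModel

    # Two equivalent ways to construct the client
    model = TritonRemoteModel('http://localhost:8000/yolov8n')  # parsed from a collective URL
    model = TritonRemoteModel(url='localhost:8000', endpoint='yolov8n', scheme='http')

    im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy input; casting to the server's dtype happens in __call__
    outputs = model(im)                                 # list of np.ndarray, one per model output
    print([o.shape for o in outputs])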