diff --git a/docs/en/guides/triton-inference-server.md b/docs/en/guides/triton-inference-server.md index 0151cc07..67d419bf 100644 --- a/docs/en/guides/triton-inference-server.md +++ b/docs/en/guides/triton-inference-server.md @@ -48,6 +48,16 @@ from ultralytics import YOLO # Load a model model = YOLO("yolo11n.pt") # load an official model +# Retreive metadata during export +metadata = [] + + +def export_cb(exporter): + metadata.append(exporter.metadata) + + +model.add_callback("on_export_end", export_cb) + # Export the model onnx_file = model.export(format="onnx", dynamic=True) ``` @@ -107,7 +117,13 @@ The Triton Model Repository is a storage location where Triton can access and lo } } } - """ + parameters { + key: "metadata" + value: { + string_value: "%s" + } + } + """ % metadata[0] with open(triton_model_path / "config.pbtxt", "w") as f: f.write(data) diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 790bb406..9d19f6ab 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.43" +__version__ = "8.3.44" import os diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 8e9b74eb..b6df3753 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -462,6 +462,7 @@ class AutoBackend(nn.Module): from ultralytics.utils.triton import TritonRemoteModel model = TritonRemoteModel(w) + metadata = model.metadata # Any other format (unsupported) else: diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py index 3f873a6f..cc53ed57 100644 --- a/ultralytics/utils/triton.py +++ b/ultralytics/utils/triton.py @@ -66,6 +66,7 @@ class TritonRemoteModel: self.np_input_formats = [type_map[x] for x in self.input_formats] self.input_names = [x["name"] for x in config["input"]] self.output_names = [x["name"] for x in config["output"]] + self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None")) def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: """