Example ORT==2.0.0-rc.5 to support onnxruntime==1.19.x (#16962)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Yan_Mr 2024-10-29 20:12:15 +08:00 committed by GitHub
parent b0c18b7190
commit 235f2d95af
7 changed files with 172 additions and 101 deletions


@@ -2,11 +2,13 @@ use anyhow::Result;
use clap::ValueEnum;
use half::f16;
use ndarray::{Array, CowArray, IxDyn};
use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
use ort::tensor::TensorElementDataType;
use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
use ort::{
CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
TensorRTExecutionProvider,
};
use ort::{Session, SessionBuilder};
use ort::{TensorElementType, ValueType};
use regex::Regex;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum YOLOTask {
// YOLO tasks
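
For context, a minimal sketch of the 2.x loading flow these new imports enable (assuming ort = 2.0.0-rc.5 and an illustrative model.onnx path):

    use anyhow::Result;
    use ort::{Session, SessionBuilder};

    fn main() -> Result<()> {
        // ort 2.x: no global Environment; build and commit the session directly.
        let session: Session = SessionBuilder::new()?.commit_from_file("model.onnx")?;
        for input in &session.inputs {
            println!("input: {}", input.name);
        }
        Ok(())
    }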
@@ -19,9 +21,9 @@ pub enum YOLOTask {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OrtEP {
// ONNXRuntime execution provider
Cpu,
Cuda(u32),
Trt(u32),
CPU,
CUDA(i32),
Trt(i32),
}
#[derive(Debug)]
@@ -44,8 +46,9 @@ impl Default for Batch {
#[derive(Debug, Default)]
pub struct OrtInputs {
// ONNX model inputs attrs
pub shapes: Vec<Vec<i32>>,
pub dtypes: Vec<TensorElementDataType>,
pub shapes: Vec<Vec<i64>>,
//pub dtypes: Vec<TensorElementDataType>,
pub dtypes: Vec<TensorElementType>,
pub names: Vec<String>,
pub sizes: Vec<Vec<u32>>,
}
@@ -56,12 +59,19 @@ impl OrtInputs {
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
/* let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
shapes.push(shape); */
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
dtypes.push(ty.clone());
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("不支持的数据格式, {} - {}", file!(), line!());
}
//dtypes.push(i.input_type);
names.push(i.name.clone());
}
Self {
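
In ort 2.x, input metadata is carried by ValueType::Tensor, and dimensions is a Vec<i64> in which dynamic axes are reported as -1 (the role the old `x as i32` / `-1i32` mapping played). A sketch of probing a session's inputs this way, mirroring the loop above:

    for input in session.inputs.iter() {
        if let ort::ValueType::Tensor { ty, dimensions } = &input.input_type {
            // e.g. a dynamic-batch YOLOv8 input may print: Float32 [-1, 3, 640, 640]
            println!("{}: {:?} {:?}", input.name, ty, dimensions);
        }
    }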
@@ -97,12 +107,14 @@ pub struct OrtBackend {
impl OrtBackend {
pub fn build(args: OrtConfig) -> Result<Self> {
// build env & session
let env = Environment::builder()
.with_name("YOLOv8")
.with_log_level(ort::LoggingLevel::Verbose)
.build()?
.into_arc();
let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
// in ort 2.x the global Environment is removed
/* let env = ort::EnvironmentBuilder
::with_name("YOLOv8")
.build()?
.into_arc(); */
let session_builder = SessionBuilder::new()?;
let session = session_builder.commit_from_file(&args.f)?;
//let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
// get inputs
let mut inputs = OrtInputs::new(&session);
@@ -142,16 +154,19 @@ impl OrtBackend {
// build provider
let (ep, provider) = match args.ep {
OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
_ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
_ => (
OrtEP::CPU,
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
),
};
// build session again with the new provider
let session = SessionBuilder::new(&env)?
let session = SessionBuilder::new()?
// .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
.with_execution_providers([provider])?
.with_model_from_file(args.f)?;
.commit_from_file(args.f)?;
// task: using given one or guessing
let task = match args.task {
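
The provider dispatch above reduces to the following 2.x pattern (a sketch: device id 0 and the model path are illustrative, and is_available() comes from the ort::ExecutionProvider trait imported at the top of the file):

    // Prefer CUDA when the provider reports itself available; otherwise fall back to CPU.
    let cuda = CUDAExecutionProvider::default().with_device_id(0);
    let provider = if let Ok(true) = cuda.is_available() {
        ExecutionProviderDispatch::from(cuda)
    } else {
        ExecutionProviderDispatch::from(CPUExecutionProvider::default())
    };
    let session = SessionBuilder::new()?
        .with_execution_providers([provider])?
        .commit_from_file("model.onnx")?;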
@@ -185,57 +200,58 @@ impl OrtBackend {
pub fn fetch_inputs_from_session(
session: &Session,
) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
// get inputs attrs from ONNX model
let mut shapes = Vec::new();
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
dtypes.push(ty.clone());
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("不支持的数据格式, {} - {}", file!(), line!());
}
names.push(i.name.clone());
}
(shapes, dtypes, names)
}
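
A hedged usage sketch of this helper (session here is an assumed, already-built Session):

    let (shapes, dtypes, names) = OrtBackend::fetch_inputs_from_session(&session);
    // shapes: Vec<Vec<i64>> with -1 for dynamic axes; dtypes: Vec<TensorElementType>
    println!("{:?} / {:?} / {:?}", shapes, dtypes, names);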
pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
// set CUDA
if ExecutionProvider::CUDA(Default::default()).is_available() {
pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
if let Ok(true) = cuda_provider.is_available() {
(
OrtEP::Cuda(device_id),
ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
device_id,
..Default::default()
}),
OrtEP::CUDA(device_id),
ExecutionProviderDispatch::from(cuda_provider),
)
} else {
println!("> CUDA is not available! Using CPU.");
(OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
(
OrtEP::CPU,
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
)
}
}
pub fn set_ep_trt(
device_id: u32,
device_id: i32,
fp16: bool,
batch: &Batch,
inputs: &OrtInputs,
) -> (OrtEP, ExecutionProvider) {
) -> (OrtEP, ExecutionProviderDispatch) {
// set TensorRT
if ExecutionProvider::TensorRT(Default::default()).is_available() {
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
// dtype match checking
if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
if let Ok(true) = trt_provider.is_available() {
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
panic!(
"Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
inputs.dtypes[0]
);
}
// dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
let mut opt_string = String::new();
let mut min_string = String::new();
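
For a single input named images (NCHW at 640x640), the assembled profile strings would look like this sketch; the 1/4/8 batch bounds are illustrative, not taken from this diff:

    let min_string = String::from("images:1x3x640x640");
    let opt_string = String::from("images:4x3x640x640");
    let max_string = String::from("images:8x3x640x640");
    let trt = TensorRTExecutionProvider::default()
        .with_profile_min_shapes(min_string)
        .with_profile_opt_shapes(opt_string)
        .with_profile_max_shapes(max_string)
        .with_fp16(true);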
@@ -251,17 +267,16 @@ impl OrtBackend {
let _ = opt_string.pop();
let _ = min_string.pop();
let _ = max_string.pop();
let trt_provider = trt_provider
.with_profile_opt_shapes(opt_string)
.with_profile_min_shapes(min_string)
.with_profile_max_shapes(max_string)
.with_fp16(fp16)
.with_timing_cache(true);
(
OrtEP::Trt(device_id),
ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
device_id,
fp16_enable: fp16,
timing_cache_enable: true,
profile_min_shapes: min_string,
profile_max_shapes: max_string,
profile_opt_shapes: opt_string,
..Default::default()
}),
ExecutionProviderDispatch::from(trt_provider),
)
} else {
println!("> TensorRT is not available! Try using CUDA...");
@@ -283,8 +298,8 @@ impl OrtBackend {
pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
// ORT inference
match self.dtype() {
TensorElementDataType::Float16 => self.run_fp16(xs, profile),
TensorElementDataType::Float32 => self.run_fp32(xs, profile),
TensorElementType::Float16 => self.run_fp16(xs, profile),
TensorElementType::Float32 => self.run_fp32(xs, profile),
_ => todo!(),
}
}
@@ -300,14 +315,13 @@ impl OrtBackend {
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
let ys = self.session.run(ort::inputs![xs.view()]?)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
@@ -315,21 +329,22 @@ impl OrtBackend {
// d2h
Ok(ys
.iter()
.map(|x| {
.map(|(_k, v)| {
// d2h
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
let v = v.try_extract_tensor().unwrap();
//let v = v.try_extract::<_>().unwrap().view().clone().into_owned();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
// f16->f32
let t_ = std::time::Instant::now();
let x = x.mapv(f16::to_f32);
let v = v.mapv(f16::to_f32);
if profile {
println!("[ORT f16->f32]: {:?}", t_.elapsed());
}
x
v
})
.collect::<Vec<Array<_, _>>>())
}
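
The fp32 path below follows the same structure; the essential 2.x run-and-extract calls reduce to this sketch (input is an assumed ndarray array):

    // Run the session, then pull each named output back as an ndarray view.
    let outputs = session.run(ort::inputs![input.view()]?)?;
    for (name, value) in outputs.iter() {
        let tensor = value.try_extract_tensor::<f32>()?; // ArrayViewD<f32>
        println!("{name}: {:?}", tensor.shape());
    }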
@@ -338,14 +353,13 @@ impl OrtBackend {
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
let ys = self.session.run(ort::inputs![xs.view()]?)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
@@ -353,39 +367,44 @@ impl OrtBackend {
// d2h
Ok(ys
.iter()
.map(|x| {
.map(|(_k, v)| {
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
//let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
x
v
})
.collect::<Vec<Array<_, _>>>())
}
pub fn output_shapes(&self) -> Vec<Vec<i32>> {
pub fn output_shapes(&self) -> Vec<Vec<i64>> {
let mut shapes = Vec::new();
for o in &self.session.outputs {
let shape: Vec<_> = o
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
for output in &self.session.outputs {
if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("not support data format, {} - {}", file!(), line!());
}
}
shapes
}
pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
pub fn output_dtypes(&self) -> Vec<TensorElementType> {
let mut dtypes = Vec::new();
self.session
.outputs
.iter()
.for_each(|x| dtypes.push(x.output_type));
for output in &self.session.outputs {
if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
dtypes.push(ty.clone());
} else {
panic!("not support data format, {} - {}", file!(), line!());
}
}
dtypes
}
pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
&self.inputs.shapes
}
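
A short usage sketch of these accessors (backend is an assumed OrtBackend instance; the example output shape is typical of a YOLOv8 detect head at 640x640, not read from this diff):

    for (i, shape) in backend.output_shapes().iter().enumerate() {
        // e.g. may print: output 0: [-1, 84, 8400]
        println!("output {i}: {shape:?}");
    }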
@@ -393,11 +412,11 @@ impl OrtBackend {
&self.inputs.names
}
pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
&self.inputs.dtypes
}
pub fn dtype(&self) -> TensorElementDataType {
pub fn dtype(&self) -> TensorElementType {
self.input_dtypes()[0]
}
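
Finally, downstream code can branch on the reported input dtype just as run() does above (backend is again an assumed OrtBackend instance):

    match backend.dtype() {
        TensorElementType::Float16 => println!("fp16 model; run_fp16 casts outputs back to f32"),
        TensorElementType::Float32 => println!("fp32 model"),
        other => println!("unhandled input dtype: {other:?}"),
    }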