Example ORT==2.0.0-rc.5 to support onnxruntime==1.19.x (#16962)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
b0c18b7190
commit
235f2d95af
7 changed files with 172 additions and 101 deletions
|
|
@ -9,11 +9,11 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
|
||||
imageproc = { version = "0.23.0", default-features = false }
|
||||
ndarray = { version = "0.15.6" }
|
||||
ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
|
||||
rusttype = { version = "0.9", default-features = false }
|
||||
image = { version = "0.25.2"}
|
||||
imageproc = { version = "0.25.0"}
|
||||
ndarray = { version = "0.16" }
|
||||
ort = { version = "2.0.0-rc.5", features = ["cuda", "tensorrt"]}
|
||||
rusttype = { version = "0.9.3" }
|
||||
anyhow = { version = "1.0.75" }
|
||||
regex = { version = "1.5.4" }
|
||||
rand = { version = "0.8.5" }
|
||||
|
|
@ -21,3 +21,4 @@ chrono = { version = "0.4.30" }
|
|||
half = { version = "2.3.1" }
|
||||
dirs = { version = "5.0.1" }
|
||||
ureq = { version = "2.9.1" }
|
||||
ab_glyph = "0.2.29"
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ This repository provides a Rust demo for performing YOLOv8 tasks like `Classific
|
|||
## Recently Updated
|
||||
|
||||
- Add YOLOv8-OBB demo
|
||||
- Update ONNXRuntime to 1.17.x
|
||||
- Update ONNXRuntime to 1.19.x
|
||||
|
||||
Newly updated YOLOv8 example code is located in this repository (https://github.com/jamjamjon/usls/tree/main/examples/yolo)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ pub struct Args {
|
|||
|
||||
/// device id
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub device_id: u32,
|
||||
pub device_id: i32,
|
||||
|
||||
/// using TensorRT EP
|
||||
#[arg(long)]
|
||||
|
|
|
|||
|
|
@ -117,3 +117,45 @@ pub fn check_font(font: &str) -> rusttype::Font<'static> {
|
|||
let buffer = std::fs::read(font_path).unwrap();
|
||||
rusttype::Font::try_from_vec(buffer).unwrap()
|
||||
}
|
||||
|
||||
|
||||
use ab_glyph::FontArc;
|
||||
pub fn load_font() -> FontArc{
|
||||
use std::path::Path;
|
||||
let font_path = Path::new("./font/Arial.ttf");
|
||||
match font_path.try_exists() {
|
||||
Ok(true) => {
|
||||
let buffer = std::fs::read(font_path).unwrap();
|
||||
FontArc::try_from_vec(buffer).unwrap()
|
||||
},
|
||||
Ok(false) => {
|
||||
std::fs::create_dir_all("./font").unwrap();
|
||||
println!("Downloading font...");
|
||||
let source_url = "https://ultralytics.com/assets/Arial.ttf";
|
||||
let resp = ureq::get(source_url)
|
||||
.timeout(std::time::Duration::from_secs(500))
|
||||
.call()
|
||||
.unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
|
||||
|
||||
// read to buffer
|
||||
let mut buffer = vec![];
|
||||
let total_size = resp
|
||||
.header("Content-Length")
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap();
|
||||
let _reader = resp
|
||||
.into_reader()
|
||||
.take(total_size)
|
||||
.read_to_end(&mut buffer)
|
||||
.unwrap();
|
||||
// save
|
||||
let mut fd = std::fs::File::create(font_path).unwrap();
|
||||
fd.write_all(&buffer).unwrap();
|
||||
println!("Font saved at: {:?}", font_path.display());
|
||||
FontArc::try_from_vec(buffer).unwrap()
|
||||
},
|
||||
Err(e) => {
|
||||
panic!("Failed to load font {}", e);
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
@ -6,7 +6,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
let args = Args::parse();
|
||||
|
||||
// 1. load image
|
||||
let x = image::io::Reader::open(&args.source)?
|
||||
let x = image::ImageReader::open(&args.source)?
|
||||
.with_guessed_format()?
|
||||
.decode()?;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
#![allow(clippy::type_complexity)]
|
||||
|
||||
use ab_glyph::FontArc;
|
||||
use anyhow::Result;
|
||||
use image::{DynamicImage, GenericImageView, ImageBuffer};
|
||||
use ndarray::{s, Array, Axis, IxDyn};
|
||||
|
|
@ -7,7 +8,7 @@ use rand::{thread_rng, Rng};
|
|||
use std::path::PathBuf;
|
||||
|
||||
use crate::{
|
||||
check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
|
||||
load_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
|
||||
OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
|
||||
};
|
||||
|
||||
|
|
@ -36,9 +37,9 @@ impl YOLOv8 {
|
|||
let ep = if config.trt {
|
||||
OrtEP::Trt(config.device_id)
|
||||
} else if config.cuda {
|
||||
OrtEP::Cuda(config.device_id)
|
||||
OrtEP::CUDA(config.device_id)
|
||||
} else {
|
||||
OrtEP::Cpu
|
||||
OrtEP::CPU
|
||||
};
|
||||
|
||||
// batch
|
||||
|
|
@ -330,12 +331,19 @@ impl YOLOv8 {
|
|||
|
||||
// coefs * proto -> mask
|
||||
let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
|
||||
let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
|
||||
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
|
||||
|
||||
let proto = proto.to_owned();
|
||||
let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw)
|
||||
let mask = coefs.dot(&proto); // (nh, nw, n)
|
||||
let mask = mask.to_shape((nh, nw, 1))?;
|
||||
|
||||
// build image from ndarray
|
||||
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
|
||||
match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
|
||||
match ImageBuffer::from_raw(
|
||||
nw as u32,
|
||||
nh as u32,
|
||||
mask.to_owned().into_raw_vec_and_offset().0,
|
||||
) {
|
||||
Some(image) => image,
|
||||
None => panic!("can not create image from ndarray"),
|
||||
};
|
||||
|
|
@ -410,7 +418,7 @@ impl YOLOv8 {
|
|||
skeletons: Option<&[(usize, usize)]>,
|
||||
) {
|
||||
// check font then load
|
||||
let font = check_font("Arial.ttf");
|
||||
let font: FontArc = load_font();
|
||||
for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
|
||||
let mut img = img0.to_rgb8();
|
||||
|
||||
|
|
@ -422,12 +430,13 @@ impl YOLOv8 {
|
|||
let legend_size = img.width().max(img.height()) / scale;
|
||||
let x = img.width() / 20;
|
||||
let y = img.height() / 20 + i as u32 * legend_size;
|
||||
|
||||
imageproc::drawing::draw_text_mut(
|
||||
&mut img,
|
||||
image::Rgb([0, 255, 0]),
|
||||
x as i32,
|
||||
y as i32,
|
||||
rusttype::Scale::uniform(legend_size as f32 - 1.),
|
||||
legend_size as f32,
|
||||
&font,
|
||||
&legend,
|
||||
);
|
||||
|
|
@ -454,7 +463,7 @@ impl YOLOv8 {
|
|||
image::Rgb(self.color_palette[bbox.id()].into()),
|
||||
bbox.xmin() as i32,
|
||||
(bbox.ymin() - legend_size as f32) as i32,
|
||||
rusttype::Scale::uniform(legend_size as f32 - 1.),
|
||||
legend_size as f32,
|
||||
&font,
|
||||
&legend,
|
||||
);
|
||||
|
|
@ -551,7 +560,7 @@ impl YOLOv8 {
|
|||
None => String::from(""),
|
||||
},
|
||||
self.engine.ep(),
|
||||
if let OrtEP::Cpu = self.engine.ep() {
|
||||
if let OrtEP::CPU = self.engine.ep() {
|
||||
""
|
||||
} else {
|
||||
"(May still fall back to CPU)"
|
||||
|
|
|
|||
|
|
@ -2,11 +2,13 @@ use anyhow::Result;
|
|||
use clap::ValueEnum;
|
||||
use half::f16;
|
||||
use ndarray::{Array, CowArray, IxDyn};
|
||||
use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
|
||||
use ort::tensor::TensorElementDataType;
|
||||
use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
|
||||
use ort::{
|
||||
CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
|
||||
TensorRTExecutionProvider,
|
||||
};
|
||||
use ort::{Session, SessionBuilder};
|
||||
use ort::{TensorElementType, ValueType};
|
||||
use regex::Regex;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
|
||||
pub enum YOLOTask {
|
||||
// YOLO tasks
|
||||
|
|
@ -19,9 +21,9 @@ pub enum YOLOTask {
|
|||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum OrtEP {
|
||||
// ONNXRuntime execution provider
|
||||
Cpu,
|
||||
Cuda(u32),
|
||||
Trt(u32),
|
||||
CPU,
|
||||
CUDA(i32),
|
||||
Trt(i32),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
@ -44,8 +46,9 @@ impl Default for Batch {
|
|||
#[derive(Debug, Default)]
|
||||
pub struct OrtInputs {
|
||||
// ONNX model inputs attrs
|
||||
pub shapes: Vec<Vec<i32>>,
|
||||
pub dtypes: Vec<TensorElementDataType>,
|
||||
pub shapes: Vec<Vec<i64>>,
|
||||
//pub dtypes: Vec<TensorElementDataType>,
|
||||
pub dtypes: Vec<TensorElementType>,
|
||||
pub names: Vec<String>,
|
||||
pub sizes: Vec<Vec<u32>>,
|
||||
}
|
||||
|
|
@ -56,12 +59,19 @@ impl OrtInputs {
|
|||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for i in session.inputs.iter() {
|
||||
let shape: Vec<i32> = i
|
||||
/* let shape: Vec<i32> = i
|
||||
.dimensions()
|
||||
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
|
||||
.collect();
|
||||
shapes.push(shape);
|
||||
dtypes.push(i.input_type);
|
||||
shapes.push(shape); */
|
||||
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
|
||||
dtypes.push(ty.clone());
|
||||
let shape = dimensions.clone();
|
||||
shapes.push(shape);
|
||||
} else {
|
||||
panic!("不支持的数据格式, {} - {}", file!(), line!());
|
||||
}
|
||||
//dtypes.push(i.input_type);
|
||||
names.push(i.name.clone());
|
||||
}
|
||||
Self {
|
||||
|
|
@ -97,12 +107,14 @@ pub struct OrtBackend {
|
|||
impl OrtBackend {
|
||||
pub fn build(args: OrtConfig) -> Result<Self> {
|
||||
// build env & session
|
||||
let env = Environment::builder()
|
||||
.with_name("YOLOv8")
|
||||
.with_log_level(ort::LoggingLevel::Verbose)
|
||||
.build()?
|
||||
.into_arc();
|
||||
let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
|
||||
// in version 2.x environment is removed
|
||||
/* let env = ort::EnvironmentBuilder
|
||||
::with_name("YOLOv8")
|
||||
.build()?
|
||||
.into_arc(); */
|
||||
let sessionbuilder = SessionBuilder::new()?;
|
||||
let session = sessionbuilder.commit_from_file(&args.f)?;
|
||||
//let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
|
||||
|
||||
// get inputs
|
||||
let mut inputs = OrtInputs::new(&session);
|
||||
|
|
@ -142,16 +154,19 @@ impl OrtBackend {
|
|||
|
||||
// build provider
|
||||
let (ep, provider) = match args.ep {
|
||||
OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
|
||||
OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
|
||||
OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
|
||||
_ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
|
||||
_ => (
|
||||
OrtEP::CPU,
|
||||
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
|
||||
),
|
||||
};
|
||||
|
||||
// build session again with the new provider
|
||||
let session = SessionBuilder::new(&env)?
|
||||
let session = SessionBuilder::new()?
|
||||
// .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
|
||||
.with_execution_providers([provider])?
|
||||
.with_model_from_file(args.f)?;
|
||||
.commit_from_file(args.f)?;
|
||||
|
||||
// task: using given one or guessing
|
||||
let task = match args.task {
|
||||
|
|
@ -185,57 +200,58 @@ impl OrtBackend {
|
|||
|
||||
pub fn fetch_inputs_from_session(
|
||||
session: &Session,
|
||||
) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
|
||||
) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
|
||||
// get inputs attrs from ONNX model
|
||||
let mut shapes = Vec::new();
|
||||
let mut dtypes = Vec::new();
|
||||
let mut names = Vec::new();
|
||||
for i in session.inputs.iter() {
|
||||
let shape: Vec<i32> = i
|
||||
.dimensions()
|
||||
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
|
||||
.collect();
|
||||
shapes.push(shape);
|
||||
dtypes.push(i.input_type);
|
||||
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
|
||||
dtypes.push(ty.clone());
|
||||
let shape = dimensions.clone();
|
||||
shapes.push(shape);
|
||||
} else {
|
||||
panic!("不支持的数据格式, {} - {}", file!(), line!());
|
||||
}
|
||||
names.push(i.name.clone());
|
||||
}
|
||||
(shapes, dtypes, names)
|
||||
}
|
||||
|
||||
pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
|
||||
// set CUDA
|
||||
if ExecutionProvider::CUDA(Default::default()).is_available() {
|
||||
pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
|
||||
let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
|
||||
if let Ok(true) = cuda_provider.is_available() {
|
||||
(
|
||||
OrtEP::Cuda(device_id),
|
||||
ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
|
||||
device_id,
|
||||
..Default::default()
|
||||
}),
|
||||
OrtEP::CUDA(device_id),
|
||||
ExecutionProviderDispatch::from(cuda_provider), //PlantForm::CUDA(cuda_provider)
|
||||
)
|
||||
} else {
|
||||
println!("> CUDA is not available! Using CPU.");
|
||||
(OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
|
||||
(
|
||||
OrtEP::CPU,
|
||||
ExecutionProviderDispatch::from(CPUExecutionProvider::default()), //PlantForm::CPU(CPUExecutionProvider::default())
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_ep_trt(
|
||||
device_id: u32,
|
||||
device_id: i32,
|
||||
fp16: bool,
|
||||
batch: &Batch,
|
||||
inputs: &OrtInputs,
|
||||
) -> (OrtEP, ExecutionProvider) {
|
||||
) -> (OrtEP, ExecutionProviderDispatch) {
|
||||
// set TensorRT
|
||||
if ExecutionProvider::TensorRT(Default::default()).is_available() {
|
||||
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
|
||||
let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
|
||||
|
||||
// dtype match checking
|
||||
if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
|
||||
//trt_provider.
|
||||
if let Ok(true) = trt_provider.is_available() {
|
||||
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
|
||||
if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
|
||||
panic!(
|
||||
"Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
|
||||
inputs.dtypes[0]
|
||||
);
|
||||
}
|
||||
|
||||
// dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
|
||||
let mut opt_string = String::new();
|
||||
let mut min_string = String::new();
|
||||
|
|
@ -251,17 +267,16 @@ impl OrtBackend {
|
|||
let _ = opt_string.pop();
|
||||
let _ = min_string.pop();
|
||||
let _ = max_string.pop();
|
||||
|
||||
let trt_provider = trt_provider
|
||||
.with_profile_opt_shapes(opt_string)
|
||||
.with_profile_min_shapes(min_string)
|
||||
.with_profile_max_shapes(max_string)
|
||||
.with_fp16(fp16)
|
||||
.with_timing_cache(true);
|
||||
(
|
||||
OrtEP::Trt(device_id),
|
||||
ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
|
||||
device_id,
|
||||
fp16_enable: fp16,
|
||||
timing_cache_enable: true,
|
||||
profile_min_shapes: min_string,
|
||||
profile_max_shapes: max_string,
|
||||
profile_opt_shapes: opt_string,
|
||||
..Default::default()
|
||||
}),
|
||||
ExecutionProviderDispatch::from(trt_provider),
|
||||
)
|
||||
} else {
|
||||
println!("> TensorRT is not available! Try using CUDA...");
|
||||
|
|
@ -283,8 +298,8 @@ impl OrtBackend {
|
|||
pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
|
||||
// ORT inference
|
||||
match self.dtype() {
|
||||
TensorElementDataType::Float16 => self.run_fp16(xs, profile),
|
||||
TensorElementDataType::Float32 => self.run_fp32(xs, profile),
|
||||
TensorElementType::Float16 => self.run_fp16(xs, profile),
|
||||
TensorElementType::Float32 => self.run_fp32(xs, profile),
|
||||
_ => todo!(),
|
||||
}
|
||||
}
|
||||
|
|
@ -300,14 +315,13 @@ impl OrtBackend {
|
|||
// h2d
|
||||
let t = std::time::Instant::now();
|
||||
let xs = CowArray::from(xs);
|
||||
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
|
||||
if profile {
|
||||
println!("[ORT H2D]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// run
|
||||
let t = std::time::Instant::now();
|
||||
let ys = self.session.run(xs)?;
|
||||
let ys = self.session.run(ort::inputs![xs.view()]?)?;
|
||||
if profile {
|
||||
println!("[ORT Inference]: {:?}", t.elapsed());
|
||||
}
|
||||
|
|
@ -315,21 +329,22 @@ impl OrtBackend {
|
|||
// d2h
|
||||
Ok(ys
|
||||
.iter()
|
||||
.map(|x| {
|
||||
.map(|(_k, v)| {
|
||||
// d2h
|
||||
let t = std::time::Instant::now();
|
||||
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
|
||||
let v = v.try_extract_tensor().unwrap();
|
||||
//let v = v.try_extract::<_>().unwrap().view().clone().into_owned();
|
||||
if profile {
|
||||
println!("[ORT D2H]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// f16->f32
|
||||
let t_ = std::time::Instant::now();
|
||||
let x = x.mapv(f16::to_f32);
|
||||
let v = v.mapv(f16::to_f32);
|
||||
if profile {
|
||||
println!("[ORT f16->f32]: {:?}", t_.elapsed());
|
||||
}
|
||||
x
|
||||
v
|
||||
})
|
||||
.collect::<Vec<Array<_, _>>>())
|
||||
}
|
||||
|
|
@ -338,14 +353,13 @@ impl OrtBackend {
|
|||
// h2d
|
||||
let t = std::time::Instant::now();
|
||||
let xs = CowArray::from(xs);
|
||||
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
|
||||
if profile {
|
||||
println!("[ORT H2D]: {:?}", t.elapsed());
|
||||
}
|
||||
|
||||
// run
|
||||
let t = std::time::Instant::now();
|
||||
let ys = self.session.run(xs)?;
|
||||
let ys = self.session.run(ort::inputs![xs.view()]?)?;
|
||||
if profile {
|
||||
println!("[ORT Inference]: {:?}", t.elapsed());
|
||||
}
|
||||
|
|
@ -353,39 +367,44 @@ impl OrtBackend {
|
|||
// d2h
|
||||
Ok(ys
|
||||
.iter()
|
||||
.map(|x| {
|
||||
.map(|(_k, v)| {
|
||||
let t = std::time::Instant::now();
|
||||
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
|
||||
let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
|
||||
//let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
|
||||
if profile {
|
||||
println!("[ORT D2H]: {:?}", t.elapsed());
|
||||
}
|
||||
x
|
||||
v
|
||||
})
|
||||
.collect::<Vec<Array<_, _>>>())
|
||||
}
|
||||
|
||||
pub fn output_shapes(&self) -> Vec<Vec<i32>> {
|
||||
pub fn output_shapes(&self) -> Vec<Vec<i64>> {
|
||||
let mut shapes = Vec::new();
|
||||
for o in &self.session.outputs {
|
||||
let shape: Vec<_> = o
|
||||
.dimensions()
|
||||
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
|
||||
.collect();
|
||||
shapes.push(shape);
|
||||
for output in &self.session.outputs {
|
||||
if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
|
||||
let shape = dimensions.clone();
|
||||
shapes.push(shape);
|
||||
} else {
|
||||
panic!("not support data format, {} - {}", file!(), line!());
|
||||
}
|
||||
}
|
||||
shapes
|
||||
}
|
||||
|
||||
pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
|
||||
pub fn output_dtypes(&self) -> Vec<TensorElementType> {
|
||||
let mut dtypes = Vec::new();
|
||||
self.session
|
||||
.outputs
|
||||
.iter()
|
||||
.for_each(|x| dtypes.push(x.output_type));
|
||||
for output in &self.session.outputs {
|
||||
if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
|
||||
dtypes.push(ty.clone());
|
||||
} else {
|
||||
panic!("not support data format, {} - {}", file!(), line!());
|
||||
}
|
||||
}
|
||||
dtypes
|
||||
}
|
||||
|
||||
pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
|
||||
pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
|
||||
&self.inputs.shapes
|
||||
}
|
||||
|
||||
|
|
@ -393,11 +412,11 @@ impl OrtBackend {
|
|||
&self.inputs.names
|
||||
}
|
||||
|
||||
pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
|
||||
pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
|
||||
&self.inputs.dtypes
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> TensorElementDataType {
|
||||
pub fn dtype(&self) -> TensorElementType {
|
||||
self.input_dtypes()[0]
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue