diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index e525159..1698a6e 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -17,7 +17,7 @@ ort-dylib = ["ort/load-dynamic"] [dependencies] libtqsm = "0.6.1" log = "0.4.18" -ndarray = "0.16.1" +ndarray = "0.17.1" once_cell = "1.18.0" rayon = { version = "1.8.1", optional = true } serde = { version = "1.0.160", features = ["derive"] } @@ -25,12 +25,16 @@ serde_json = "1.0.89" thiserror = "1.0.47" [dependencies.ort] -version = "=2.0.0-rc.9" +version = "=2.0.0-rc.11" optional = true +default-features = false +features = ["std", "load-dynamic"] [dependencies.ort-sys] -version = "=2.0.0-rc.9" +version = "=2.0.0-rc.11" optional = true [dev-dependencies.ort] -version = "=2.0.0-rc.9" +version = "=2.0.0-rc.11" +default-features = false +features = ["std", "load-dynamic"] diff --git a/crates/core/src/backend/ort.rs b/crates/core/src/backend/ort.rs index 05a274d..139a247 100644 --- a/crates/core/src/backend/ort.rs +++ b/crates/core/src/backend/ort.rs @@ -1,7 +1,8 @@ use crate::{InferenceEngine, LibtashkeelError, LibtashkeelResult}; use ndarray::{Array1, Array2}; -use ort::{session::{Session, builder::GraphOptimizationLevel,}}; +use ort::{session::{Session, builder::GraphOptimizationLevel,}, value::Value}; use std::path::Path; +use std::sync::Mutex; impl From for LibtashkeelError { fn from(other: ort::Error) -> Self { @@ -13,7 +14,7 @@ impl From for LibtashkeelError { } fn ort_session_run( - session: &Session, + session: &Mutex, input_ids: Vec, diac_ids: Vec, seq_length: usize, @@ -23,16 +24,20 @@ fn ort_session_run( let input_length = Array1::::from_iter([seq_length as i64]); let (target_ids, logits): (Vec, Vec) = { + let input_ids_value = Value::from_array(input_ids)?; + let diac_ids_value = Value::from_array(diac_ids)?; + let input_length_value = Value::from_array(input_length)?; let inputs = ort::inputs![ - input_ids, - diac_ids, - input_length, - ]?; + input_ids_value, + diac_ids_value, + input_length_value, + ]; + let mut session = session.lock().unwrap(); let outputs = session.run(inputs)?; - let target_ids = outputs[0].try_extract_tensor::()?; - let logits = outputs[1].try_extract_tensor::()?; - let target_ids_vec = Vec::from_iter(target_ids.view().iter().copied()); - let logits_vec = Vec::from_iter(logits.view().iter().copied()); + let (_, target_ids) = outputs[0].try_extract_tensor::()?; + let (_, logits) = outputs[1].try_extract_tensor::()?; + let target_ids_vec = Vec::from_iter(target_ids.iter().copied()); + let logits_vec = Vec::from_iter(logits.iter().copied()); (target_ids_vec, logits_vec) }; @@ -41,7 +46,7 @@ fn ort_session_run( const MODEL_BYTES: &[u8] = include_bytes!("../../data/ort/model.onnx"); -pub struct OrtEngine(Session); +pub struct OrtEngine(Mutex); impl OrtEngine { pub fn from_bytes(model_bytes: &[u8]) -> LibtashkeelResult { @@ -52,7 +57,7 @@ impl OrtEngine { .with_intra_threads(2)? .commit_from_memory(model_bytes)?; - Ok(Self(session)) + Ok(Self(Mutex::new(session))) } pub fn from_path(model_path: impl AsRef) -> LibtashkeelResult { let session = Session::builder()? @@ -64,7 +69,7 @@ impl OrtEngine { // .with_intra_threads(2)? .commit_from_file(model_path)?; - Ok(Self(session)) + Ok(Self(Mutex::new(session))) } pub fn with_bundled_model() -> LibtashkeelResult { Self::from_bytes(MODEL_BYTES)