Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,24 @@ ort-dylib = ["ort/load-dynamic"]
[dependencies]
libtqsm = "0.6.1"
log = "0.4.18"
ndarray = "0.16.1"
ndarray = "0.17.1"
once_cell = "1.18.0"
rayon = { version = "1.8.1", optional = true }
serde = { version = "1.0.160", features = ["derive"] }
serde_json = "1.0.89"
thiserror = "1.0.47"

[dependencies.ort]
version = "=2.0.0-rc.9"
version = "=2.0.0-rc.11"
optional = true
default-features = false
features = ["std", "load-dynamic"]

[dependencies.ort-sys]
version = "=2.0.0-rc.9"
version = "=2.0.0-rc.11"
optional = true

[dev-dependencies.ort]
version = "=2.0.0-rc.9"
version = "=2.0.0-rc.11"
default-features = false
features = ["std", "load-dynamic"]
31 changes: 18 additions & 13 deletions crates/core/src/backend/ort.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::{InferenceEngine, LibtashkeelError, LibtashkeelResult};
use ndarray::{Array1, Array2};
use ort::{session::{Session, builder::GraphOptimizationLevel,}};
use ort::{session::{Session, builder::GraphOptimizationLevel,}, value::Value};
use std::path::Path;
use std::sync::Mutex;

impl From<ort::Error> for LibtashkeelError {
fn from(other: ort::Error) -> Self {
Expand All @@ -13,7 +14,7 @@ impl From<ort::Error> for LibtashkeelError {
}

fn ort_session_run(
session: &Session,
session: &Mutex<Session>,
input_ids: Vec<i64>,
diac_ids: Vec<i64>,
seq_length: usize,
Expand All @@ -23,16 +24,20 @@ fn ort_session_run(
let input_length = Array1::<i64>::from_iter([seq_length as i64]);

let (target_ids, logits): (Vec<u8>, Vec<f32>) = {
let input_ids_value = Value::from_array(input_ids)?;
let diac_ids_value = Value::from_array(diac_ids)?;
let input_length_value = Value::from_array(input_length)?;
let inputs = ort::inputs![
input_ids,
diac_ids,
input_length,
]?;
input_ids_value,
diac_ids_value,
input_length_value,
];
let mut session = session.lock().unwrap();
let outputs = session.run(inputs)?;
let target_ids = outputs[0].try_extract_tensor::<u8>()?;
let logits = outputs[1].try_extract_tensor::<f32>()?;
let target_ids_vec = Vec::from_iter(target_ids.view().iter().copied());
let logits_vec = Vec::from_iter(logits.view().iter().copied());
let (_, target_ids) = outputs[0].try_extract_tensor::<u8>()?;
let (_, logits) = outputs[1].try_extract_tensor::<f32>()?;
let target_ids_vec = Vec::from_iter(target_ids.iter().copied());
let logits_vec = Vec::from_iter(logits.iter().copied());
(target_ids_vec, logits_vec)
};

Expand All @@ -41,7 +46,7 @@ fn ort_session_run(

const MODEL_BYTES: &[u8] = include_bytes!("../../data/ort/model.onnx");

pub struct OrtEngine(Session);
pub struct OrtEngine(Mutex<Session>);

impl OrtEngine {
pub fn from_bytes(model_bytes: &[u8]) -> LibtashkeelResult<OrtEngine> {
Expand All @@ -52,7 +57,7 @@ impl OrtEngine {
.with_intra_threads(2)?
.commit_from_memory(model_bytes)?;

Ok(Self(session))
Ok(Self(Mutex::new(session)))
}
pub fn from_path(model_path: impl AsRef<Path>) -> LibtashkeelResult<Self> {
let session = Session::builder()?
Expand All @@ -64,7 +69,7 @@ impl OrtEngine {
// .with_intra_threads(2)?
.commit_from_file(model_path)?;

Ok(Self(session))
Ok(Self(Mutex::new(session)))
}
pub fn with_bundled_model() -> LibtashkeelResult<OrtEngine> {
Self::from_bytes(MODEL_BYTES)
Expand Down
Loading