Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion soundevents/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "soundevents"
version = "0.2.0"
version = "0.2.1"
edition = "2024"
description = "Production-oriented Rust inference wrapper for CED AudioSet classifiers."
license.workspace = true
Expand Down Expand Up @@ -28,13 +28,20 @@ include = [
default = []
bundled-tiny = []

serde = ["dep:serde"]

[dependencies]
ort = "2.0.0-rc.12"
smol_str = "0.3"
thiserror = { workspace = true, features = ["default"] }

soundevents-dataset = { workspace = true, features = ["rated"] }

serde = { workspace = true, optional = true, features = ["derive"] }

[dev-dependencies]
serde_json = "1"

[lints.rust]
rust_2018_idioms = "warn"
single_use_lifetimes = "warn"
Expand Down
149 changes: 149 additions & 0 deletions soundevents/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use std::{
path::{Path, PathBuf},
};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// The expected input sample rate for CED models.
pub const SAMPLE_RATE_HZ: usize = 16_000;

Expand All @@ -29,10 +32,85 @@ pub const NUM_CLASSES: usize = RatedSoundEvent::events().len();
const BUNDLED_TINY_MODEL: &[u8] =
include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/tiny.onnx"));

#[cfg(feature = "serde")]
mod graph_optimization_level {
use super::GraphOptimizationLevel;
use serde::*;

#[derive(
Debug, Default, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize,
)]
#[serde(rename_all = "snake_case")]
enum OptimizationLevel {
#[default]
Disable,
Level1,
Level2,
Level3,
All,
}

impl From<GraphOptimizationLevel> for OptimizationLevel {
#[inline]
fn from(value: GraphOptimizationLevel) -> Self {
match value {
GraphOptimizationLevel::Disable => Self::Disable,
GraphOptimizationLevel::Level1 => Self::Level1,
GraphOptimizationLevel::Level2 => Self::Level2,
GraphOptimizationLevel::Level3 => Self::Level3,
GraphOptimizationLevel::All => Self::All,
}
}
}

impl From<OptimizationLevel> for GraphOptimizationLevel {
#[inline]
fn from(value: OptimizationLevel) -> Self {
match value {
OptimizationLevel::Disable => Self::Disable,
OptimizationLevel::Level1 => Self::Level1,
OptimizationLevel::Level2 => Self::Level2,
OptimizationLevel::Level3 => Self::Level3,
OptimizationLevel::All => Self::All,
}
}
}

#[cfg_attr(not(tarpaulin), inline(always))]
pub fn serialize<S>(level: &GraphOptimizationLevel, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
OptimizationLevel::from(*level).serialize(serializer)
}

#[cfg_attr(not(tarpaulin), inline(always))]
pub fn deserialize<'de, D>(deserializer: D) -> Result<GraphOptimizationLevel, D::Error>
where
D: Deserializer<'de>,
{
OptimizationLevel::deserialize(deserializer).map(Into::into)
}

#[cfg_attr(not(tarpaulin), inline(always))]
pub const fn default() -> GraphOptimizationLevel {
GraphOptimizationLevel::Disable
}
}

/// Options for constructing a [`Classifier`] from an ONNX model on disk.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Options {
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
model_path: Option<PathBuf>,
#[cfg_attr(
feature = "serde",
serde(
default = "graph_optimization_level::default",
with = "graph_optimization_level"
)
)]
optimization_level: GraphOptimizationLevel,
}

Expand Down Expand Up @@ -1565,4 +1643,75 @@ mod tests {
Restriction::try_from("bogus").expect_err("unknown token surfaced");
assert_eq!(err.name(), "bogus");
}

#[cfg(feature = "serde")]
#[test]
fn test_serde() {
let opts = Options::default()
.with_model_path("some/model.onnx")
.with_optimization_level(GraphOptimizationLevel::Level2);
let serialized = serde_json::to_string(&opts).expect("serialize options");
let deserialized: Options = serde_json::from_str(&serialized).expect("deserialize options");
assert_eq!(opts.model_path, deserialized.model_path);
assert_eq!(opts.optimization_level, deserialized.optimization_level);

let default_deserialized: Options =
serde_json::from_str("{}").expect("deserialize default options");
assert!(default_deserialized.model_path.is_none());
assert!(matches!(
default_deserialized.optimization_level,
GraphOptimizationLevel::Disable
));

// level1
let level1_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level1);
let level1_serialized = serde_json::to_string(&level1_opts).expect("serialize level1 options");
let level1_deserialized: Options =
serde_json::from_str(&level1_serialized).expect("deserialize level1 options");
assert!(matches!(
level1_deserialized.optimization_level,
GraphOptimizationLevel::Level1
));

// level2
let level2_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level2);
let level2_serialized = serde_json::to_string(&level2_opts).expect("serialize level2 options");
let level2_deserialized: Options =
serde_json::from_str(&level2_serialized).expect("deserialize level2 options");
assert!(matches!(
level2_deserialized.optimization_level,
GraphOptimizationLevel::Level2
));

// level3
let level3_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level3);
let level3_serialized = serde_json::to_string(&level3_opts).expect("serialize level3 options");
let level3_deserialized: Options =
serde_json::from_str(&level3_serialized).expect("deserialize level3 options");
assert!(matches!(
level3_deserialized.optimization_level,
GraphOptimizationLevel::Level3
));

// all
let all_opts = Options::default().with_optimization_level(GraphOptimizationLevel::All);
let all_serialized = serde_json::to_string(&all_opts).expect("serialize all options");
let all_deserialized: Options =
serde_json::from_str(&all_serialized).expect("deserialize all options");
assert!(matches!(
all_deserialized.optimization_level,
GraphOptimizationLevel::All
));

// disable
let disable_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Disable);
let disable_serialized =
serde_json::to_string(&disable_opts).expect("serialize disable options");
let disable_deserialized: Options =
serde_json::from_str(&disable_serialized).expect("deserialize disable options");
assert!(matches!(
disable_deserialized.optimization_level,
GraphOptimizationLevel::Disable
));
}
}