From 1e0ab515aa783c31354803729e2fa884cad3cc46 Mon Sep 17 00:00:00 2001 From: uqio <276879906+uqio@users.noreply.github.com> Date: Tue, 21 Apr 2026 15:42:26 +1200 Subject: [PATCH] feat(serde): support serde on `soundevents::Options` --- soundevents/Cargo.toml | 9 ++- soundevents/src/lib.rs | 149 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 1 deletion(-) diff --git a/soundevents/Cargo.toml b/soundevents/Cargo.toml index 4e35cbb..92dcbb3 100644 --- a/soundevents/Cargo.toml +++ b/soundevents/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "soundevents" -version = "0.2.0" +version = "0.2.1" edition = "2024" description = "Production-oriented Rust inference wrapper for CED AudioSet classifiers." license.workspace = true @@ -28,6 +28,8 @@ include = [ default = [] bundled-tiny = [] +serde = ["dep:serde"] + [dependencies] ort = "2.0.0-rc.12" smol_str = "0.3" @@ -35,6 +37,11 @@ thiserror = { workspace = true, features = ["default"] } soundevents-dataset = { workspace = true, features = ["rated"] } +serde = { workspace = true, optional = true, features = ["derive"] } + +[dev-dependencies] +serde_json = "1" + [lints.rust] rust_2018_idioms = "warn" single_use_lifetimes = "warn" diff --git a/soundevents/src/lib.rs b/soundevents/src/lib.rs index 957a32e..4a4df7c 100644 --- a/soundevents/src/lib.rs +++ b/soundevents/src/lib.rs @@ -16,6 +16,9 @@ use std::{ path::{Path, PathBuf}, }; +#[cfg(feature = "serde")] +use serde::{Deserialize, Serialize}; + /// The expected input sample rate for CED models. pub const SAMPLE_RATE_HZ: usize = 16_000; @@ -29,10 +32,85 @@ pub const NUM_CLASSES: usize = RatedSoundEvent::events().len(); const BUNDLED_TINY_MODEL: &[u8] = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/models/tiny.onnx")); +#[cfg(feature = "serde")] +mod graph_optimization_level { + use super::GraphOptimizationLevel; + use serde::*; + + #[derive( + Debug, Default, Clone, Copy, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize, Deserialize, + )] + #[serde(rename_all = "snake_case")] + enum OptimizationLevel { + #[default] + Disable, + Level1, + Level2, + Level3, + All, + } + + impl From for OptimizationLevel { + #[inline] + fn from(value: GraphOptimizationLevel) -> Self { + match value { + GraphOptimizationLevel::Disable => Self::Disable, + GraphOptimizationLevel::Level1 => Self::Level1, + GraphOptimizationLevel::Level2 => Self::Level2, + GraphOptimizationLevel::Level3 => Self::Level3, + GraphOptimizationLevel::All => Self::All, + } + } + } + + impl From for GraphOptimizationLevel { + #[inline] + fn from(value: OptimizationLevel) -> Self { + match value { + OptimizationLevel::Disable => Self::Disable, + OptimizationLevel::Level1 => Self::Level1, + OptimizationLevel::Level2 => Self::Level2, + OptimizationLevel::Level3 => Self::Level3, + OptimizationLevel::All => Self::All, + } + } + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub fn serialize(level: &GraphOptimizationLevel, serializer: S) -> Result + where + S: Serializer, + { + OptimizationLevel::from(*level).serialize(serializer) + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + OptimizationLevel::deserialize(deserializer).map(Into::into) + } + + #[cfg_attr(not(tarpaulin), inline(always))] + pub const fn default() -> GraphOptimizationLevel { + GraphOptimizationLevel::Disable + } +} + /// Options for constructing a [`Classifier`] from an ONNX model on disk. #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Options { + #[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))] model_path: Option, + #[cfg_attr( + feature = "serde", + serde( + default = "graph_optimization_level::default", + with = "graph_optimization_level" + ) + )] optimization_level: GraphOptimizationLevel, } @@ -1565,4 +1643,75 @@ mod tests { Restriction::try_from("bogus").expect_err("unknown token surfaced"); assert_eq!(err.name(), "bogus"); } + + #[cfg(feature = "serde")] + #[test] + fn test_serde() { + let opts = Options::default() + .with_model_path("some/model.onnx") + .with_optimization_level(GraphOptimizationLevel::Level2); + let serialized = serde_json::to_string(&opts).expect("serialize options"); + let deserialized: Options = serde_json::from_str(&serialized).expect("deserialize options"); + assert_eq!(opts.model_path, deserialized.model_path); + assert_eq!(opts.optimization_level, deserialized.optimization_level); + + let default_deserialized: Options = + serde_json::from_str("{}").expect("deserialize default options"); + assert!(default_deserialized.model_path.is_none()); + assert!(matches!( + default_deserialized.optimization_level, + GraphOptimizationLevel::Disable + )); + + // level1 + let level1_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level1); + let level1_serialized = serde_json::to_string(&level1_opts).expect("serialize level1 options"); + let level1_deserialized: Options = + serde_json::from_str(&level1_serialized).expect("deserialize level1 options"); + assert!(matches!( + level1_deserialized.optimization_level, + GraphOptimizationLevel::Level1 + )); + + // level2 + let level2_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level2); + let level2_serialized = serde_json::to_string(&level2_opts).expect("serialize level2 options"); + let level2_deserialized: Options = + serde_json::from_str(&level2_serialized).expect("deserialize level2 options"); + assert!(matches!( + level2_deserialized.optimization_level, + GraphOptimizationLevel::Level2 + )); + + // level3 + let level3_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Level3); + let level3_serialized = serde_json::to_string(&level3_opts).expect("serialize level3 options"); + let level3_deserialized: Options = + serde_json::from_str(&level3_serialized).expect("deserialize level3 options"); + assert!(matches!( + level3_deserialized.optimization_level, + GraphOptimizationLevel::Level3 + )); + + // all + let all_opts = Options::default().with_optimization_level(GraphOptimizationLevel::All); + let all_serialized = serde_json::to_string(&all_opts).expect("serialize all options"); + let all_deserialized: Options = + serde_json::from_str(&all_serialized).expect("deserialize all options"); + assert!(matches!( + all_deserialized.optimization_level, + GraphOptimizationLevel::All + )); + + // disable + let disable_opts = Options::default().with_optimization_level(GraphOptimizationLevel::Disable); + let disable_serialized = + serde_json::to_string(&disable_opts).expect("serialize disable options"); + let disable_deserialized: Options = + serde_json::from_str(&disable_serialized).expect("deserialize disable options"); + assert!(matches!( + disable_deserialized.optimization_level, + GraphOptimizationLevel::Disable + )); + } }