From f2d1309dfbfaad9326bfad03477f02149b27dee3 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 29 Apr 2026 09:46:28 +0300 Subject: [PATCH 1/4] Add audio crate dependencies --- Cargo.lock | 47 ++++++++++++++++++++++++++++++++++++++++++ catgrad-llm/Cargo.toml | 2 ++ 2 files changed, 49 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 3c2324ac..8661b1cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -419,6 +419,7 @@ dependencies = [ "graphviz-rust", "half", "hf-hub", + "hound", "image", "log", "memmap2", @@ -427,6 +428,7 @@ dependencies = [ "open-hypergraphs", "open-hypergraphs-dot", "rayon", + "rustfft", "safetensors 0.7.0", "serde", "serde_json", @@ -1483,6 +1485,12 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + [[package]] name = "http" version = "1.4.0" @@ -2435,6 +2443,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "primal-check" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc0d895b311e3af9902528fbb8f928688abbd95872819320517cc24ca6b2bd08" +dependencies = [ + "num-integer", +] + [[package]] name = "proc-macro2" version = "1.0.103" @@ -2740,6 +2757,20 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustfft" +version = "6.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21db5f9893e91f41798c88680037dba611ca6674703c1a18601b01a72c8adb89" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "primal-check", + "strength_reduce", + "transpose", +] + [[package]] name = "rustix" version = "1.1.2" @@ -3004,6 +3035,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.11.1" @@ -3352,6 +3389,16 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "transpose" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad61aed86bc3faea4300c7aee358b4c6d0c8d6ccc36524c96e4c92ccf26e77e" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "typed-builder" version = "0.23.2" diff --git a/catgrad-llm/Cargo.toml b/catgrad-llm/Cargo.toml index 3143227e..7c223487 100644 --- a/catgrad-llm/Cargo.toml +++ b/catgrad-llm/Cargo.toml @@ -32,6 +32,8 @@ serde_with = { version = "3.17", default-features = false, features = ["macros"] serde_path_to_error = "0.1" ureq = "2.12.1" url = "2.5.7" +hound = "3.5.1" +rustfft = "6.4.1" [dev-dependencies] From f36c3c7463dd9e0deb64fbf09cc023c67a665e06 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 29 Apr 2026 09:47:12 +0300 Subject: [PATCH 2/4] Add audio helper to preprocess wav files --- catgrad-llm/src/utils/audio.rs | 224 +++++++++++++++++++++++++++++++++ catgrad-llm/src/utils/mod.rs | 3 + 2 files changed, 227 insertions(+) create mode 100644 catgrad-llm/src/utils/audio.rs diff --git a/catgrad-llm/src/utils/audio.rs b/catgrad-llm/src/utils/audio.rs new file mode 100644 index 00000000..546337b9 --- /dev/null +++ b/catgrad-llm/src/utils/audio.rs @@ -0,0 +1,224 @@ +use crate::{LLMError, Result}; +use hound::{SampleFormat, WavReader}; +use rustfft::{FftPlanner, num_complex::Complex32}; +use std::path::Path; + +// This has Gemma4 specific audio constants as it is only used by the Gemma 4 multimodal models +// Clean up when other models support audio +pub const AUDIO_SAMPLE_RATE: u32 = 16_000; +pub const AUDIO_FEATURE_SIZE: usize = 128; +pub const AUDIO_FRAME_LENGTH: usize = 320; +pub const AUDIO_HOP_LENGTH: usize = 160; +pub const AUDIO_FFT_LENGTH: usize = 512; +pub const AUDIO_MEL_FLOOR: f32 = 1e-3; +const GEMMA4_AUDIO_MAX_SAMPLES: usize = 480_000; +const GEMMA4_AUDIO_PAD_TO_MULTIPLE_OF: usize = 128; + +#[derive(Debug, Clone)] +pub struct PreparedAudioFeatures { + pub features: Vec, + pub feature_shape: Vec, + pub mask: Vec, + pub mask_shape: Vec, + pub num_mel_frames: usize, + pub valid_mel_frames: usize, +} + +pub fn load_wav_file(path: &Path) -> Result> { + let mut reader = + WavReader::open(path).map_err(|err| LLMError::IoError(std::io::Error::other(err)))?; + let spec = reader.spec(); + if spec.channels != 1 { + return Err(LLMError::UnsupportedWireConversion(format!( + "expected mono wav input, found {} channels", + spec.channels + ))); + } + if spec.sample_rate != AUDIO_SAMPLE_RATE { + return Err(LLMError::UnsupportedWireConversion(format!( + "expected {AUDIO_SAMPLE_RATE}Hz wav input, found {}Hz", + spec.sample_rate + ))); + } + + let samples = match spec.sample_format { + SampleFormat::Float => reader + .samples::() + .collect::, _>>() + .map_err(|err| LLMError::IoError(std::io::Error::other(err)))?, + SampleFormat::Int => { + let scale = ((1i64 << (spec.bits_per_sample.saturating_sub(1) as u32)) - 1) as f32; + if scale <= 0.0 { + return Err(LLMError::UnsupportedWireConversion(format!( + "unsupported wav bit depth {}", + spec.bits_per_sample + ))); + } + reader + .samples::() + .map(|sample| { + sample + .map(|sample| sample as f32 / scale) + .map_err(|err| LLMError::IoError(std::io::Error::other(err))) + }) + .collect::>>()? + } + }; + + Ok(samples) +} + +pub fn gemma4_audio_mel_frame_count(num_samples: usize) -> usize { + let padded_samples = num_samples + AUDIO_FRAME_LENGTH / 2; + let frame_size_for_unfold = AUDIO_FRAME_LENGTH + 1; + if padded_samples < frame_size_for_unfold { + 0 + } else { + (padded_samples - frame_size_for_unfold) / AUDIO_HOP_LENGTH + 1 + } +} + +pub fn compute_log_mel_spectrogram(waveform: &[f32]) -> Result<(Vec, Vec, usize)> { + let waveform = if waveform.len() > GEMMA4_AUDIO_MAX_SAMPLES { + &waveform[..GEMMA4_AUDIO_MAX_SAMPLES] + } else { + waveform + }; + let padded_len = round_up_to_multiple(waveform.len(), GEMMA4_AUDIO_PAD_TO_MULTIPLE_OF); + let mut padded_input = vec![0.0f32; padded_len]; + padded_input[..waveform.len()].copy_from_slice(waveform); + let mut sample_mask = vec![false; padded_len]; + sample_mask[..waveform.len()].fill(true); + + let frame_size_for_unfold = AUDIO_FRAME_LENGTH + 1; + let pad_left = AUDIO_FRAME_LENGTH / 2; + let mut padded_waveform = vec![0.0f32; pad_left + padded_len]; + padded_waveform[pad_left..].copy_from_slice(&padded_input); + let mut padded_mask = vec![false; pad_left + padded_len]; + padded_mask[pad_left..].copy_from_slice(&sample_mask); + if padded_waveform.len() < frame_size_for_unfold { + return Ok((Vec::new(), Vec::new(), 0)); + } + + let num_frames = gemma4_audio_mel_frame_count(padded_len); + let window = hann_window(); + let mel_filters = create_mel_filterbank( + AUDIO_FEATURE_SIZE, + AUDIO_FFT_LENGTH, + AUDIO_SAMPLE_RATE as f32, + 0.0, + AUDIO_SAMPLE_RATE as f32 / 2.0, + ); + let mut planner = FftPlanner::::new(); + let fft = planner.plan_fft_forward(AUDIO_FFT_LENGTH); + + let mut features = Vec::with_capacity(num_frames * AUDIO_FEATURE_SIZE); + let mut mask = Vec::with_capacity(num_frames); + for frame_idx in 0..num_frames { + let start = frame_idx * AUDIO_HOP_LENGTH; + let raw_frame = &padded_waveform[start..start + frame_size_for_unfold]; + let mut windowed = raw_frame[..AUDIO_FRAME_LENGTH] + .iter() + .zip(window.iter()) + .map(|(sample, coeff)| Complex32::new(sample * coeff, 0.0)) + .collect::>(); + windowed.resize(AUDIO_FFT_LENGTH, Complex32::new(0.0, 0.0)); + fft.process(&mut windowed); + + let magnitude = windowed[..AUDIO_FFT_LENGTH / 2 + 1] + .iter() + .map(|complex| complex.norm()) + .collect::>(); + + let valid = padded_mask[start + frame_size_for_unfold - 1]; + for filter in &mel_filters { + let mel = filter + .iter() + .enumerate() + .fold(0.0f32, |acc, (freq_idx, coeff)| { + acc + magnitude[freq_idx] * coeff + }); + features.push(if valid { + (mel + AUDIO_MEL_FLOOR).ln() + } else { + 0.0 + }); + } + mask.push(if valid { 0.0 } else { 1.0 }); + } + + Ok((features, mask, num_frames)) +} + +pub fn prepare_audio_features(path: &Path) -> Result { + let waveform = load_wav_file(path)?; + let (features, mask, num_mel_frames) = compute_log_mel_spectrogram(&waveform)?; + if num_mel_frames == 0 { + return Err(LLMError::UnsupportedWireConversion( + "audio input produced no log-mel frames".to_string(), + )); + } + let valid_mel_frames = mask.iter().filter(|&&value| value == 0.0).count(); + Ok(PreparedAudioFeatures { + feature_shape: vec![1, num_mel_frames, AUDIO_FEATURE_SIZE], + features, + mask_shape: vec![1, num_mel_frames], + mask, + num_mel_frames, + valid_mel_frames, + }) +} + +fn hann_window() -> Vec { + let arg = std::f32::consts::PI * 2.0 / AUDIO_FRAME_LENGTH as f32; + (0..AUDIO_FRAME_LENGTH) + .map(|idx| 0.5 - 0.5 * (arg * idx as f32).cos()) + .collect() +} + +fn create_mel_filterbank( + n_mels: usize, + n_fft: usize, + sample_rate: f32, + min_frequency: f32, + max_frequency: f32, +) -> Vec> { + let n_freqs = n_fft / 2 + 1; + let all_freqs = (0..n_freqs) + .map(|idx| idx as f32 * sample_rate / n_fft as f32) + .collect::>(); + + let hz_to_mel = |hz: f32| 2595.0 * (1.0 + hz / 700.0).log10(); + let mel_to_hz = |mel: f32| 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0); + + let min_mel = hz_to_mel(min_frequency); + let max_mel = hz_to_mel(max_frequency); + let mut freq_points = Vec::with_capacity(n_mels + 2); + for idx in 0..(n_mels + 2) { + let mel = min_mel + (max_mel - min_mel) * idx as f32 / (n_mels + 1) as f32; + freq_points.push(mel_to_hz(mel)); + } + + let mut filterbank = vec![vec![0.0; n_freqs]; n_mels]; + for mel_idx in 0..n_mels { + let left = freq_points[mel_idx]; + let center = freq_points[mel_idx + 1]; + let right = freq_points[mel_idx + 2]; + for (freq_idx, &freq) in all_freqs.iter().enumerate() { + if freq >= left && freq <= center && center > left { + filterbank[mel_idx][freq_idx] = (freq - left) / (center - left); + } else if freq > center && freq <= right && right > center { + filterbank[mel_idx][freq_idx] = (right - freq) / (right - center); + } + } + } + filterbank +} + +fn round_up_to_multiple(value: usize, multiple: usize) -> usize { + if value == 0 || multiple == 0 { + value + } else { + value.div_ceil(multiple) * multiple + } +} diff --git a/catgrad-llm/src/utils/mod.rs b/catgrad-llm/src/utils/mod.rs index 68dfe33b..42387e61 100644 --- a/catgrad-llm/src/utils/mod.rs +++ b/catgrad-llm/src/utils/mod.rs @@ -25,6 +25,9 @@ mod images; pub(crate) use images::convert_image_to_patches; pub use images::*; +mod audio; +pub use audio::*; + fn build_hf_api() -> Result { let mut builder = ApiBuilder::from_env(); let env_token = std::env::var("HF_TOKEN") From 0ec7c9759195cccab0b026973208f682bf1535d0 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 29 Apr 2026 09:48:29 +0300 Subject: [PATCH 3/4] llm.py: accept wav file inputs to test audio models (Gemma4) --- catgrad-llm/scripts/llm.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/catgrad-llm/scripts/llm.py b/catgrad-llm/scripts/llm.py index 07f262e1..aa804a2d 100644 --- a/catgrad-llm/scripts/llm.py +++ b/catgrad-llm/scripts/llm.py @@ -10,7 +10,7 @@ from transformers import ( AutoModelForCausalLM, - AutoModelForImageTextToText, + AutoModelForMultimodalLM, AutoProcessor, AutoTokenizer, logging, @@ -247,6 +247,7 @@ def run_tool_chat(tokenizer, model, prompt, args): parser.add_argument("-p", "--prompt", type=str, default="Category theory is") parser.add_argument("-s", "--seq-len", type=int, default=10) parser.add_argument("-i", "--image", type=str, default=None) + parser.add_argument("-a", "--audio", type=str, default=None) parser.add_argument("-r", "--raw", action="store_true") parser.add_argument("-t", "--thinking", action="store_true") parser.add_argument( @@ -266,19 +267,19 @@ def run_tool_chat(tokenizer, model, prompt, args): if args.tool_use and args.raw: parser.error("--tool-use does not support --raw") - if args.image is None: + if args.image is None and args.audio is None: tokenizer = AutoTokenizer.from_pretrained(args.model, revision=args.revision) try: model = AutoModelForCausalLM.from_pretrained( args.model, revision=args.revision, dtype=args.dtype ) except: - model = AutoModelForImageTextToText.from_pretrained( + model = AutoModelForMultimodalLM.from_pretrained( args.model, revision=args.revision, dtype=args.dtype ) else: processor = AutoProcessor.from_pretrained(args.model, revision=args.revision) - model = AutoModelForImageTextToText.from_pretrained( + model = AutoModelForMultimodalLM.from_pretrained( args.model, revision=args.revision, dtype=args.dtype ) @@ -290,6 +291,7 @@ def run_tool_chat(tokenizer, model, prompt, args): if ( args.image is None + and args.audio is None and not args.raw and not args.tool_use and tokenizer.chat_template is not None @@ -309,7 +311,7 @@ def run_tool_chat(tokenizer, model, prompt, args): model.generation_config.top_p = None model.generation_config.top_k = None - if args.image is None: + if args.image is None and args.audio is None: if args.tool_use: output = run_tool_chat(tokenizer, model, prompt, args) else: @@ -322,13 +324,16 @@ def run_tool_chat(tokenizer, model, prompt, args): ) output = tokenizer.decode(logits[0], skip_special_tokens=True) else: + content = [{"type": "text", "text": prompt}] + if args.image: + content += [{"type": "image", "path": args.image}] + if args.audio: + content += [{"type": "audio", "path": args.audio}] + messages = [ { "role": "user", - "content": [ - {"type": "text", "text": prompt}, - {"type": "image", "path": args.image}, - ], + "content": content, } ] try: From 7cb00de18cb479eec893f1c74e9a89307ec8f9ee Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Wed, 29 Apr 2026 09:50:04 +0300 Subject: [PATCH 4/4] Granite-4 tool calling parser --- catgrad-llm/src/helpers/tool_calls.rs | 4 ++++ catgrad-llm/src/models/granite.rs | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/catgrad-llm/src/helpers/tool_calls.rs b/catgrad-llm/src/helpers/tool_calls.rs index bd19b8d2..c554356a 100644 --- a/catgrad-llm/src/helpers/tool_calls.rs +++ b/catgrad-llm/src/helpers/tool_calls.rs @@ -50,6 +50,10 @@ pub fn parse_qwen3_5_tool_calls(output: &str) -> Result> { ) } +pub fn parse_granite_tool_calls(output: &str) -> Result> { + parse_qwen3_tool_calls(output) +} + pub fn parse_lfm2_tool_calls(output: &str) -> Result> { parse_python_tool_calls(output, "<|tool_call_start|>", "<|tool_call_end|>") } diff --git a/catgrad-llm/src/models/granite.rs b/catgrad-llm/src/models/granite.rs index b746c9e0..d1a213ef 100644 --- a/catgrad-llm/src/models/granite.rs +++ b/catgrad-llm/src/models/granite.rs @@ -77,6 +77,10 @@ impl LLMModel for GraniteModel { fn dtype(&self) -> Dtype { self.dtype } + + fn parse_tool_calls(&self, output: &str) -> crate::Result> { + parse_granite_tool_calls(output) + } } impl GraniteModel {