From 936e5f5a5f57faee48ebcfd3dfa006746e62b0e6 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Wed, 29 Apr 2026 10:23:33 +0300
Subject: [PATCH 1/3] Add Conformer encoder model

---
 catgrad-llm/src/models/conformer.rs | 1111 +++++++++++++++++++++++++++
 catgrad-llm/src/models/mod.rs       |    1 +
 2 files changed, 1112 insertions(+)
 create mode 100644 catgrad-llm/src/models/conformer.rs

diff --git a/catgrad-llm/src/models/conformer.rs b/catgrad-llm/src/models/conformer.rs
new file mode 100644
index 00000000..f0b97385
--- /dev/null
+++ b/catgrad-llm/src/models/conformer.rs
@@ -0,0 +1,1111 @@
+// Conformer model used by Gemma4 Audio
+#![allow(clippy::too_many_arguments)]
+use crate::helpers::*;
+use catgrad::prelude::ops::*;
+use catgrad::prelude::*;
+use nn::*;
+use serde::Deserialize;
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Gemma4AudioConfig {
+    pub hidden_size: usize,
+    pub num_hidden_layers: usize,
+    pub num_attention_heads: usize,
+    pub hidden_act: String,
+    pub subsampling_conv_channels: Vec<usize>,
+    pub conv_kernel_size: usize,
+    pub residual_weight: f32,
+    pub attention_chunk_size: usize,
+    pub attention_context_left: usize,
+    pub attention_context_right: usize,
+    pub attention_logit_cap: f32,
+    pub attention_invalid_logits_value: f32,
+    pub use_clipped_linears: bool,
+    pub rms_norm_eps: f32,
+    pub gradient_clipping: f32,
+    pub output_proj_dims: usize,
+}
+
+impl Gemma4AudioConfig {
+    pub fn num_soft_tokens_for_frames(&self, frames: usize) -> usize {
+        let mut frames = frames;
+        for _ in 0..2 {
+            frames = frames.div_ceil(2);
+        }
+        frames.min(750)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Gemma4AudioTower {
+    pub config: Gemma4AudioConfig,
+    pub input_time_steps: usize,
+    pub input_feature_bins: usize,
+}
+
+impl Gemma4AudioTower {
+    pub fn audio_model(&self, builder: &Builder, features: Var, mask: Var) -> Var {
+        let mut x = reshape(
+            builder,
+            shape!(builder, 1, 1, self.input_time_steps, self.input_feature_bins),
+            features,
+        );
+        let mut mask = reshape(builder, shape!(builder, 1, self.input_time_steps), mask);
+        (x, mask) = self.sscp_block(
+            builder,
+            path(vec!["model", "audio_tower", "subsample_conv_projection", "layer0"]).unwrap(),
+            1,
+            self.config.subsampling_conv_channels[0],
+            self.input_time_steps,
+            self.input_feature_bins,
+            x,
+            mask,
+        );
+        let time_after_layer0 = self.input_time_steps.div_ceil(2);
+        let freq_after_layer0 = self.input_feature_bins.div_ceil(2);
+        (x, mask) = self.sscp_block(
+            builder,
+            path(vec!["model", "audio_tower", "subsample_conv_projection", "layer1"]).unwrap(),
+            self.config.subsampling_conv_channels[0],
+            self.config.subsampling_conv_channels[1],
+            time_after_layer0,
+            freq_after_layer0,
+            x,
+            mask,
+        );
+
+        let time_after_layer1 = time_after_layer0.div_ceil(2);
+        let freq_after_layer1 = freq_after_layer0.div_ceil(2);
+        let x = transpose(builder, 1, 2, x);
+        let x = transpose(builder, 2, 3, x);
+        let x = reshape(
+            builder,
+            shape!(
+                builder,
+                1,
+                time_after_layer1,
+                freq_after_layer1 * self.config.subsampling_conv_channels[1]
+            ),
+            x,
+        );
+        let mut x = linear_no_bias(
+            builder,
+            freq_after_layer1 * self.config.subsampling_conv_channels[1],
+            self.config.hidden_size,
+            path(vec!["model", "audio_tower", "subsample_conv_projection", "input_proj_linear"])
+                .unwrap(),
+            x,
+        );
+        let position_embeddings = relative_position_embeddings(
+            builder,
+            self.config.hidden_size,
+            self.config.attention_context_left.saturating_sub(1)
+                + self.config.attention_context_right
+                + 1,
+        );
+        let attention_mask = blocked_bidirectional_attention_mask(
+            builder,
+            time_after_layer1,
+            self.config.attention_chunk_size,
+            self.config.attention_context_left.saturating_sub(1),
+            self.config.attention_context_right,
+            mask,
+        );
+
+        for layer_id in 0..self.config.num_hidden_layers {
+            let layer_path = path(vec!["model", "audio_tower", "layers"])
+                .unwrap()
+                .extend([&layer_id.to_string()])
+                .unwrap();
+            x = self.conformer_block(
+                builder,
+                layer_path,
+                time_after_layer1,
+                attention_mask.clone(),
+                position_embeddings.clone(),
+                x,
+            );
+        }
+
+        linear(
+            builder,
+            self.config.hidden_size,
+            self.config.output_proj_dims,
+            path(vec!["model", "audio_tower", "output_proj"]).unwrap(),
+            x,
+        )
+    }
+
+    fn conformer_block(
+        &self,
+        builder: &Builder,
+        p: Path,
+        seq_len: usize,
+        attention_mask: Var,
+        position_embeddings: Var,
+        x: Var,
+    ) -> Var {
+        let x = self.feed_forward(builder, p.extend(["feed_forward1"]).unwrap(), x);
+        let residual = x.clone();
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["norm_pre_attn"]).unwrap(), x);
+        let x = self.attention_block(
+            builder,
+            p.clone(),
+            seq_len,
+            attention_mask,
+            position_embeddings,
+            x,
+        );
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["norm_post_attn"]).unwrap(), x);
+        let x = residual + x;
+        let x = self.light_conv(builder, p.extend(["lconv1d"]).unwrap(), x);
+        let x = self.feed_forward(builder, p.extend(["feed_forward2"]).unwrap(), x);
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["norm_out"]).unwrap(), x)
+    }
+
+    fn feed_forward(&self, builder: &Builder, p: Path, x: Var) -> Var {
+        let residual = x.clone();
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["pre_layer_norm"]).unwrap(), x);
+        let x = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size * 4,
+            p.extend(["ffw_layer_1"]).unwrap(),
+            x,
+        );
+        let x = activation(builder, &self.config.hidden_act, x);
+        let x = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size * 4,
+            self.config.hidden_size,
+            p.extend(["ffw_layer_2"]).unwrap(),
+            x,
+        );
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["post_layer_norm"]).unwrap(), x);
+        let scale = constant(builder, self.config.residual_weight, &shape(builder, x.clone()));
+        let scale = cast(builder, scale, dtype(builder, x.clone()));
+        residual + x * scale
+    }
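
A note on the residual convention in feed_forward: the block output is
residual + residual_weight * FFN(x), with the weight taken from the config
rather than hard-coded. A standalone sketch of that arithmetic on plain
slices (illustration only, not crate code; in the original Conformer the two
Macaron-style FFN halves use a fixed 0.5):

    fn scaled_residual(residual: &[f32], ffn_out: &[f32], residual_weight: f32) -> Vec<f32> {
        residual
            .iter()
            .zip(ffn_out)
            .map(|(r, y)| r + residual_weight * y)
            .collect()
    }
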
+
+    fn attention_block(
+        &self,
+        builder: &Builder,
+        p: Path,
+        seq_len: usize,
+        attention_mask: Var,
+        position_embeddings: Var,
+        x: Var,
+    ) -> Var {
+        let x = self.local_attention(
+            builder,
+            p.extend(["self_attn"]).unwrap(),
+            seq_len,
+            attention_mask,
+            position_embeddings,
+            x,
+        );
+        let x = reshape(builder, shape!(builder, 1, seq_len, self.config.hidden_size), x);
+        clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size,
+            p.extend(["self_attn", "post"]).unwrap(),
+            x,
+        )
+    }
+
+    fn local_attention(
+        &self,
+        builder: &Builder,
+        p: Path,
+        seq_len: usize,
+        attention_mask: Var,
+        position_embeddings: Var,
+        x: Var,
+    ) -> Var {
+        let num_heads = self.config.num_attention_heads;
+        let head_dim = self.config.hidden_size / num_heads;
+        let chunk_size = self.config.attention_chunk_size;
+        let max_past = self.config.attention_context_left.saturating_sub(1);
+        let max_future = self.config.attention_context_right;
+        let context_size = chunk_size + max_past + max_future;
+        let num_blocks = seq_len.div_ceil(chunk_size);
+        let q_scale = head_dim as f32;
+        let q_scale = q_scale.powf(-0.5) / std::f32::consts::LN_2;
+        let k_scale = (1.0f32 + std::f32::consts::E).ln() / std::f32::consts::LN_2;
+
+        let q = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size,
+            p.extend(["q_proj"]).unwrap(),
+            x.clone(),
+        );
+        let k = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size,
+            p.extend(["k_proj"]).unwrap(),
+            x.clone(),
+        );
+        let v = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size,
+            p.extend(["v_proj"]).unwrap(),
+            x,
+        );
+
+        let q = reshape(builder, shape!(builder, 1, seq_len, num_heads, head_dim), q);
+        let k = reshape(builder, shape!(builder, 1, seq_len, num_heads, head_dim), k);
+        let v = reshape(builder, shape!(builder, 1, seq_len, num_heads, head_dim), v);
+
+        let q_scale = constant(builder, q_scale, &shape(builder, q.clone()));
+        let q_scale = cast(builder, q_scale, dtype(builder, q.clone()));
+        let q = q * q_scale;
+        let per_dim_scale = param(builder, &p.extend(["per_dim_scale"]).unwrap());
+        let per_dim_scale = softplus(builder, per_dim_scale);
+        let per_dim_scale = reshape(builder, shape!(builder, 1, 1, 1, head_dim), per_dim_scale);
+        let q = q.clone() * broadcast(builder, shape(builder, q), per_dim_scale);
+        let k_scale = constant(builder, k_scale, &shape(builder, k.clone()));
+        let k_scale = cast(builder, k_scale, dtype(builder, k.clone()));
+        let k = k * k_scale;
+
+        let q_blocks = blockify_4d(builder, seq_len, chunk_size, num_heads, head_dim, q);
+        let k_blocks = extract_block_context_4d(
+            builder,
+            seq_len,
+            max_past,
+            max_future,
+            chunk_size,
+            context_size,
+            num_heads,
+            head_dim,
+            k,
+        );
+        let v_blocks = extract_block_context_4d(
+            builder,
+            seq_len,
+            max_past,
+            max_future,
+            chunk_size,
+            context_size,
+            num_heads,
+            head_dim,
+            v,
+        );
+
+        let mut logits = relative_attention_logits(
+            builder,
+            p,
+            position_embeddings,
+            num_blocks,
+            chunk_size,
+            context_size,
+            num_heads,
+            head_dim,
+            q_blocks,
+            k_blocks,
+        );
+        let softcap = constant(
+            builder,
+            self.config.attention_logit_cap,
+            &shape(builder, logits.clone()),
+        );
+        let softcap = cast(builder, softcap, dtype(builder, logits.clone()));
+        logits = tanh(builder, logits / softcap.clone()) * softcap;
+
+        let logits = masked_fill(
+            builder,
+            broadcast(builder, shape(builder, logits.clone()), attention_mask),
+            self.config.attention_invalid_logits_value,
+            logits,
+        );
+        let attn = softmax(builder, logits);
+        let mut head_outputs = Vec::with_capacity(num_heads);
+        for head in 0..num_heads {
+            let attn_head = slice(builder, 1, head, 1, attn.clone());
+            let attn_head = squeeze::<5, 4>(builder, 1, attn_head);
+            let value_head = slice(builder, 3, head, 1, v_blocks.clone());
+            let value_head = squeeze::<5, 4>(builder, 3, value_head);
+
+            let mut block_outputs = Vec::with_capacity(num_blocks);
+            for block in 0..num_blocks {
+                let attn_block = slice(builder, 1, block, 1, attn_head.clone());
+                let attn_block = squeeze::<4, 3>(builder, 1, attn_block);
+                let value_block = slice(builder, 1, block, 1, value_head.clone());
+                let value_block = squeeze::<4, 3>(builder, 1, value_block);
+                let out = matmul(builder, attn_block, value_block);
+                block_outputs.push(unsqueeze::<3, 4>(builder, 1, out));
+            }
+
+            let head_output = block_outputs
+                .into_iter()
+                .reduce(|acc, item| concat(builder, 1, acc, item))
+                .unwrap();
+            head_outputs.push(unsqueeze::<4, 5>(builder, 3, head_output));
+        }
+
+        let attn = head_outputs
+            .into_iter()
+            .reduce(|acc, item| concat(builder, 3, acc, item))
+            .unwrap();
+        let attn = reshape(
+            builder,
+            shape!(builder, 1, num_blocks * chunk_size, num_heads * head_dim),
+            attn,
+        );
+        slice(builder, 1, 0, seq_len, attn)
+    }
+
+    fn light_conv(&self, builder: &Builder, p: Path, x: Var) -> Var {
+        let residual = x.clone();
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["pre_layer_norm"]).unwrap(), x);
+        let x = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size * 2,
+            p.extend(["linear_start"]).unwrap(),
+            x,
+        );
+        let [value, gate] = chunk(builder, 2, 2, self.config.hidden_size, x)
+            .try_into()
+            .unwrap();
+        let x = value * sigmoid(builder, gate);
+        let x = transpose(builder, 1, 2, x);
+        let x = depthwise_conv1d_no_bias(
+            builder,
+            p.extend(["depthwise_conv1d"]).unwrap(),
+            self.config.conv_kernel_size,
+            x,
+            self.config.conv_kernel_size - 1,
+        );
+        let x = transpose(builder, 1, 2, x);
+        let x = clamp(builder, x, -self.config.gradient_clipping, self.config.gradient_clipping);
+        let x = rmsnorm::<3>(builder, self.config.rms_norm_eps, p.extend(["conv_norm"]).unwrap(), x);
+        let x = activation(builder, &self.config.hidden_act, x);
+        let x = clippable_linear_no_bias(
+            builder,
+            self.config.use_clipped_linears,
+            self.config.hidden_size,
+            self.config.hidden_size,
+            p.extend(["linear_end"]).unwrap(),
+            x,
+        );
+        residual + x
+    }
+
+    fn sscp_block(
+        &self,
+        builder: &Builder,
+        p: Path,
+        in_channels: usize,
+        out_channels: usize,
+        time_steps: usize,
+        freq_bins: usize,
+        x: Var,
+        mask: Var,
+    ) -> (Var, Var) {
+        let out_time = time_steps.div_ceil(2);
+        let out_freq = freq_bins.div_ceil(2);
+        let x_dtype = dtype(builder, x.clone());
+        let zero = constant(builder, 0.0, &shape(builder, mask.clone()));
+        let valid = eq(builder, mask.clone(), zero);
+        let valid = cast(builder, valid, x_dtype);
+        let valid = reshape(builder, shape!(builder, 1, 1, time_steps, 1), valid);
+        let valid = broadcast(
+            builder,
+            shape!(builder, 1, in_channels, time_steps, freq_bins),
+            valid,
+        );
+        let x = x * valid;
+        let x = conv2d_stride2_square_no_bias(
+            builder,
+            p.extend(["conv"]).unwrap(),
+            in_channels,
+            out_channels,
+            time_steps,
+            freq_bins,
+            x,
+        );
+        let mask = subsample_mask(builder, time_steps, out_time, mask);
+        let x = transpose(builder, 1, 2, x);
+        let x = transpose(builder, 2, 3, x);
+        let x = reshape(builder, shape!(builder, 1, out_time * out_freq, out_channels), x);
+        let x = layernorm_weight_only(builder, self.config.rms_norm_eps, p.extend(["norm"]).unwrap(), x);
+        let x = reshape(builder, shape!(builder, 1, out_time, out_freq, out_channels), x);
+        let x = relu(builder, x);
+        let x = transpose(builder, 2, 3, x);
+        (transpose(builder, 1, 2, x), mask)
+    }
+}
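
local_attention above is chunked: each block of chunk_size queries attends to
a key/value window of chunk_size + max_past + max_future positions, cut from
a sequence zero-padded by max_past on the left and max_future + chunk_size - 1
on the right. An index-level sketch of the windows extract_block_context_4d
slices out (illustration only; offsets are in the padded sequence):

    fn block_context_ranges(
        seq_len: usize,
        chunk_size: usize,
        max_past: usize,
        max_future: usize,
    ) -> Vec<(usize, usize)> {
        let num_blocks = seq_len.div_ceil(chunk_size);
        let context_size = chunk_size + max_past + max_future;
        (0..num_blocks)
            .map(|block| {
                // Block b's window starts at b * chunk_size in the padded
                // sequence and always spans exactly context_size positions.
                let start = block * chunk_size;
                (start, start + context_size)
            })
            .collect()
    }
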
+
+fn activation(builder: &Builder, hidden_act: &str, x: Var) -> Var {
+    match hidden_act {
+        "silu" => silu(builder, x),
+        other => panic!("unsupported Gemma4 audio activation `{other}`"),
+    }
+}
+
+fn relu(builder: &Builder, x: Var) -> Var {
+    clamp(builder, x, 0.0, f32::MAX)
+}
+
+fn layernorm_weight_only(builder: &Builder, eps: f32, p: Path, x: Var) -> Var {
+    let x = layernorm_raw(builder, eps, x);
+    let weight = param(builder, &p.extend(["weight"]).unwrap());
+    x.clone() * broadcast(builder, shape(builder, x), weight)
+}
+
+fn clippable_linear_no_bias(
+    builder: &Builder,
+    use_clipped_linears: bool,
+    in_features: usize,
+    out_features: usize,
+    p: Path,
+    x: Var,
+) -> Var {
+    let x = if use_clipped_linears {
+        let sh = shape(builder, x.clone());
+        let input_min = broadcast(
+            builder,
+            sh.clone(),
+            param(builder, &p.extend(["input_min"]).unwrap()),
+        );
+        let input_max = broadcast(
+            builder,
+            sh,
+            param(builder, &p.extend(["input_max"]).unwrap()),
+        );
+        clamp_with_tensors(builder, x, input_min, input_max)
+    } else {
+        x
+    };
+    let x = linear_no_bias(builder, in_features, out_features, p.extend(["linear"]).unwrap(), x);
+    if use_clipped_linears {
+        let sh = shape(builder, x.clone());
+        let output_min = broadcast(
+            builder,
+            sh.clone(),
+            param(builder, &p.extend(["output_min"]).unwrap()),
+        );
+        let output_max = broadcast(
+            builder,
+            sh,
+            param(builder, &p.extend(["output_max"]).unwrap()),
+        );
+        clamp_with_tensors(builder, x, output_min, output_max)
+    } else {
+        x
+    }
+}
+
+fn relative_attention_logits(
+    builder: &Builder,
+    p: Path,
+    position_embeddings: Var,
+    num_blocks: usize,
+    chunk_size: usize,
+    context_size: usize,
+    num_heads: usize,
+    head_dim: usize,
+    q_blocks: Var,
+    k_blocks: Var,
+) -> Var {
+    let rel_k = linear_no_bias(
+        builder,
+        num_heads * head_dim,
+        num_heads * head_dim,
+        p.extend(["relative_k_proj"]).unwrap(),
+        position_embeddings,
+    );
+    let rel_count = context_size - chunk_size + 1;
+    let rel_k = reshape(builder, shape!(builder, 1, rel_count, num_heads, head_dim), rel_k);
+    let rel_k = transpose(builder, 1, 2, rel_k);
+    let rel_k = transpose(builder, 2, 3, rel_k);
+
+    let q_ac = transpose(builder, 1, 3, q_blocks);
+    let q_ac = transpose(builder, 2, 3, q_ac);
+    let k_ac = transpose(builder, 1, 3, k_blocks);
+    let k_ac = transpose(builder, 2, 3, k_ac);
+    let k_ac = transpose(builder, 3, 4, k_ac);
+
+    let mut head_terms = Vec::with_capacity(num_heads);
+    for head in 0..num_heads {
+        let q_head = slice(builder, 1, head, 1, q_ac.clone());
+        let q_head = squeeze::<5, 4>(builder, 1, q_head);
+        let k_head = slice(builder, 1, head, 1, k_ac.clone());
+        let k_head = squeeze::<5, 4>(builder, 1, k_head);
+        let rel_head = slice(builder, 1, head, 1, rel_k.clone());
+        let rel_head = squeeze::<4, 3>(builder, 1, rel_head);
+
+        let mut block_terms = Vec::with_capacity(num_blocks);
+        for block in 0..num_blocks {
+            let q_block = slice(builder, 1, block, 1, q_head.clone());
+            let q_block = squeeze::<4, 3>(builder, 1, q_block);
+            let k_block = slice(builder, 1, block, 1, k_head.clone());
+            let k_block = squeeze::<4, 3>(builder, 1, k_block);
+            let term_ac = matmul(builder, q_block.clone(), k_block);
+
+            let q_block_flat = reshape(builder, shape!(builder, 1, chunk_size, head_dim), q_block);
+            let term_bd = matmul(builder, q_block_flat, rel_head.clone());
+            let term_bd = reshape(builder, shape!(builder, 1, chunk_size, rel_count), term_bd);
+            let term_bd = relative_shift_single_head(builder, chunk_size, context_size, term_bd);
+
+            let term = term_ac + term_bd;
+            block_terms.push(unsqueeze::<3, 4>(builder, 1, term));
+        }
+
+        let head_term = block_terms
+            .into_iter()
+            .reduce(|acc, item| concat(builder, 1, acc, item))
+            .unwrap();
+        head_terms.push(unsqueeze::<4, 5>(builder, 1, head_term));
+    }
+
+    head_terms
+        .into_iter()
+        .reduce(|acc, item| concat(builder, 1, acc, item))
+        .unwrap()
+}
+
+fn relative_shift_single_head(
+    builder: &Builder,
+    chunk_size: usize,
+    context_size: usize,
+    term_bd: Var,
+) -> Var {
+    let rel_count = context_size - chunk_size + 1;
+    let pad = context_size + 1 - rel_count;
+    let term_bd = if pad > 0 {
+        let zeros = zeros(
+            builder,
+            &shape!(builder, 1, chunk_size, pad),
+            dtype(builder, term_bd.clone()),
+        );
+        concat(builder, 2, term_bd, zeros)
+    } else {
+        term_bd
+    };
+    let term_bd = reshape(builder, shape!(builder, 1, chunk_size * (context_size + 1)), term_bd);
+    let term_bd = slice(builder, 1, 0, chunk_size * context_size, term_bd);
+    reshape(builder, shape!(builder, 1, chunk_size, context_size), term_bd)
+}
+
+fn relative_position_embeddings(builder: &Builder, hidden_size: usize, count: usize) -> Var {
+    let positions = cast(builder, arange(builder, count), Dtype::F32);
+    let positions = constant(
+        builder,
+        (count.saturating_sub(1)) as f32,
+        &shape(builder, positions.clone()),
+    ) - positions;
+    let positions = reshape(builder, shape!(builder, 1, count, 1), positions);
+
+    let num_timescales = hidden_size / 2;
+    let min_timescale = 1.0f32;
+    let max_timescale = 10_000.0f32;
+    let log_increment =
+        (max_timescale / min_timescale).ln() / num_timescales.saturating_sub(1).max(1) as f32;
+    let inv_idx = cast(builder, arange(builder, num_timescales), Dtype::F32);
+    let inv_idx = constant(builder, -log_increment, &shape(builder, inv_idx.clone())) * inv_idx;
+    let inv_timescales = exp(builder, inv_idx);
+    let inv_timescales =
+        inv_timescales.clone() * constant(builder, min_timescale, &shape(builder, inv_timescales));
+    let inv_timescales = reshape(builder, shape!(builder, 1, 1, num_timescales), inv_timescales);
+    let scaled_time_shape = shape!(builder, 1, count, num_timescales);
+    let positions = broadcast(builder, scaled_time_shape.clone(), positions);
+    let inv_timescales = broadcast(builder, scaled_time_shape, inv_timescales);
+    let scaled_time = positions * inv_timescales;
+    concat(builder, 2, sin(builder, scaled_time.clone()), cos(builder, scaled_time))
+}
+
+fn sliding_window_valid_mask(
+    builder: &Builder,
+    seq_len: usize,
+    max_past: usize,
+    max_future: usize,
+) -> Var {
+    let idx = cast(builder, arange(builder, seq_len), Dtype::F32);
+    let row = reshape(builder, shape!(builder, seq_len, 1), idx.clone());
+    let row = broadcast(builder, shape!(builder, seq_len, seq_len), row);
+    let col = reshape(builder, shape!(builder, 1, seq_len), idx);
+    let col = broadcast(builder, shape!(builder, seq_len, seq_len), col);
+    let past_ok = lte(
+        builder,
+        row.clone(),
+        col.clone() + constant(builder, max_past as f32, &shape!(builder, seq_len, seq_len)),
+    );
+    let future_ok = lte(
+        builder,
+        col,
+        row + constant(builder, max_future as f32, &shape!(builder, seq_len, seq_len)),
+    );
+    cast(builder, past_ok, Dtype::F32) * cast(builder, future_ok, Dtype::F32)
+}
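
sliding_window_valid_mask evaluates a simple per-pair predicate tensor-wise.
The scalar form, for one (query row, key col) pair (illustration only, not
crate code):

    fn is_valid(row: usize, col: usize, max_past: usize, max_future: usize) -> bool {
        let past_ok = row <= col + max_past; // i.e. col >= row - max_past, without underflow
        let future_ok = col <= row + max_future;
        past_ok && future_ok
    }
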
+
+fn subsample_mask(builder: &Builder, in_len: usize, out_len: usize, mask: Var) -> Var {
+    let items = (0..out_len)
+        .map(|idx| {
+            let start = (idx * 2).min(in_len.saturating_sub(1));
+            slice(builder, 1, start, 1, mask.clone())
+        })
+        .collect::<Vec<_>>();
+    items
+        .into_iter()
+        .reduce(|acc, item| concat(builder, 1, acc, item))
+        .unwrap()
+}
+
+fn pad_time_4d(
+    builder: &Builder,
+    num_heads: usize,
+    head_dim: usize,
+    left: usize,
+    right: usize,
+    x: Var,
+) -> Var {
+    let x_dtype = dtype(builder, x.clone());
+    let left_pad = zeros(builder, &shape!(builder, 1, left, num_heads, head_dim), x_dtype.clone());
+    let right_pad = zeros(builder, &shape!(builder, 1, right, num_heads, head_dim), x_dtype);
+    let x = concat(builder, 1, left_pad, x);
+    concat(builder, 1, x, right_pad)
+}
+
+fn pad_time_2d(builder: &Builder, left: usize, right: usize, x: Var) -> Var {
+    let x_dtype = dtype(builder, x.clone());
+    let left_pad = zeros(builder, &shape!(builder, 1, left), x_dtype.clone());
+    let right_pad = zeros(builder, &shape!(builder, 1, right), x_dtype);
+    let x = concat(builder, 1, left_pad, x);
+    concat(builder, 1, x, right_pad)
+}
+
+fn blockify_4d(
+    builder: &Builder,
+    seq_len: usize,
+    chunk_size: usize,
+    num_heads: usize,
+    head_dim: usize,
+    x: Var,
+) -> Var {
+    let num_blocks = seq_len.div_ceil(chunk_size);
+    let pad = num_blocks * chunk_size - seq_len;
+    let x = if pad > 0 {
+        pad_time_4d(builder, num_heads, head_dim, 0, pad, x)
+    } else {
+        x
+    };
+    reshape(
+        builder,
+        shape!(builder, 1, num_blocks, chunk_size, num_heads, head_dim),
+        x,
+    )
+}
+
+fn extract_block_context_4d(
+    builder: &Builder,
+    seq_len: usize,
+    max_past: usize,
+    max_future: usize,
+    chunk_size: usize,
+    context_size: usize,
+    num_heads: usize,
+    head_dim: usize,
+    x: Var,
+) -> Var {
+    let num_blocks = seq_len.div_ceil(chunk_size);
+    let x = pad_time_4d(
+        builder,
+        num_heads,
+        head_dim,
+        max_past,
+        max_future + chunk_size - 1,
+        x,
+    );
+    let items = (0..num_blocks)
+        .map(|idx| {
+            let start = idx * chunk_size;
+            let window = slice(builder, 1, start, context_size, x.clone());
+            unsqueeze::<4, 5>(builder, 1, window)
+        })
+        .collect::<Vec<_>>();
+    items
+        .into_iter()
+        .reduce(|acc, item| concat(builder, 1, acc, item))
+        .unwrap()
+}
+
+fn blocked_bidirectional_attention_mask(
+    builder: &Builder,
+    seq_len: usize,
+    chunk_size: usize,
+    max_past: usize,
+    max_future: usize,
+    mask: Var,
+) -> Var {
+    let num_blocks = seq_len.div_ceil(chunk_size);
+    let padded_seq_len = num_blocks * chunk_size;
+    let pad_amount = padded_seq_len - seq_len;
+    let valid_1d = cast(
+        builder,
+        eq(builder, mask, constant(builder, 0.0, &shape!(builder, 1, seq_len))),
+        Dtype::F32,
+    );
+    let valid_1d = if pad_amount > 0 {
+        pad_time_2d(builder, 0, pad_amount, valid_1d)
+    } else {
+        valid_1d
+    };
+    let query_valid = reshape(builder, shape!(builder, 1, 1, padded_seq_len, 1), valid_1d.clone());
+    let query_valid = broadcast(
+        builder,
+        shape!(builder, 1, 1, padded_seq_len, padded_seq_len),
+        query_valid,
+    );
+    let key_valid = reshape(builder, shape!(builder, 1, 1, 1, padded_seq_len), valid_1d);
+    let key_valid = broadcast(
+        builder,
+        shape!(builder, 1, 1, padded_seq_len, padded_seq_len),
+        key_valid,
+    );
+    let sliding = sliding_window_valid_mask(builder, padded_seq_len, max_past, max_future);
+    let sliding = unsqueeze::<2, 3>(builder, 0, sliding);
+    let sliding = unsqueeze::<3, 4>(builder, 0, sliding);
+    let sliding = broadcast(
+        builder,
+        shape!(builder, 1, 1, padded_seq_len, padded_seq_len),
+        sliding,
+    );
+    let valid_4d = query_valid * key_valid * sliding;
+    let valid_5d = convert_4d_mask_to_blocked_5d(
+        builder,
+        padded_seq_len,
+        chunk_size,
+        max_past,
+        max_future,
+        valid_4d,
+    );
+    let one = constant(builder, 1.0, &shape(builder, valid_5d.clone()));
+    one - valid_5d
+}
+
+fn convert_4d_mask_to_blocked_5d(
+    builder: &Builder,
+    padded_seq_len: usize,
+    chunk_size: usize,
+    max_past: usize,
+    max_future: usize,
+    mask_4d: Var,
+) -> Var {
+    let num_blocks = padded_seq_len.div_ceil(chunk_size);
+    let context_size = chunk_size + max_past + max_future;
+    let mask_5d = reshape(
+        builder,
+        shape!(builder, 1, 1, num_blocks, chunk_size, padded_seq_len),
+        mask_4d,
+    );
+    let mask_5d = pad_kv_5d(builder, max_past, max_future, mask_5d);
+    let items = (0..num_blocks)
+        .map(|block_idx| {
+            let start = block_idx * chunk_size;
+            let block = slice(builder, 2, block_idx, 1, mask_5d.clone());
+            let block = squeeze::<5, 4>(builder, 2, block);
+            let block = slice(builder, 3, start, context_size, block);
+            unsqueeze::<4, 5>(builder, 2, block)
+        })
+        .collect::<Vec<_>>();
+    items
+        .into_iter()
+        .reduce(|acc, item| concat(builder, 2, acc, item))
+        .unwrap()
+}
+
+fn pad_kv_5d(builder: &Builder, left: usize, right: usize, x: Var) -> Var {
+    let x_dtype = dtype(builder, x.clone());
+    let [b, one, num_blocks, chunk_size, _seq_len] =
+        unpack::<5>(builder, shape(builder, x.clone()));
+    let left_pad = zeros(
+        builder,
+        &shape!(builder, b, one, num_blocks, chunk_size, left),
+        x_dtype.clone(),
+    );
+    let right_pad = zeros(
+        builder,
+        &shape!(builder, b, one, num_blocks, chunk_size, right),
+        x_dtype,
+    );
+    let x = concat(builder, 4, left_pad, x);
+    concat(builder, 4, x, right_pad)
+}
+
+fn conv2d_stride2_square_no_bias(
+    builder: &Builder,
+    p: Path,
+    in_channels: usize,
+    out_channels: usize,
+    time_steps: usize,
+    freq_bins: usize,
+    x: Var,
+) -> Var {
+    let out_time = time_steps.div_ceil(2);
+    let out_freq = freq_bins.div_ceil(2);
+    let x = pad_4d(builder, x, in_channels, time_steps, freq_bins, 1, 2, 1, 2);
+
+    let mut windows = Vec::with_capacity(9);
+    for kernel_t in 0..3 {
+        let xt = slice(builder, 2, kernel_t, 2 * out_time, x.clone());
+        let xt = reshape(
+            builder,
+            shape!(builder, 1, in_channels, out_time, 2, freq_bins + 3),
+            xt,
+        );
+        let xt = squeeze::<5, 4>(builder, 3, slice(builder, 3, 0, 1, xt));
+        for kernel_f in 0..3 {
+            let xf = slice(builder, 3, kernel_f, 2 * out_freq, xt.clone());
+            let xf = reshape(
+                builder,
+                shape!(builder, 1, in_channels, out_time, out_freq, 2),
+                xf,
+            );
+            let xf = squeeze::<5, 4>(builder, 4, slice(builder, 4, 0, 1, xf));
+            let xf = transpose(builder, 1, 2, xf);
+            let xf = transpose(builder, 2, 3, xf);
+            windows.push(xf);
+        }
+    }
+
+    let x = windows
+        .into_iter()
+        .map(|item| unsqueeze::<4, 5>(builder, 4, item))
+        .reduce(|acc, item| concat(builder, 4, acc, item))
+        .unwrap();
+    let x = reshape(
+        builder,
+        shape!(builder, 1, out_time * out_freq, in_channels * 9),
+        x,
+    );
+    let weight = param(builder, &p.extend(["weight"]).unwrap());
+    let weight = reshape(builder, shape!(builder, out_channels, in_channels * 9), weight);
+    let weight = transpose(builder, 0, 1, weight);
+    let weight = broadcast(builder, shape!(builder, 1, in_channels * 9, out_channels), weight);
+    let x = matmul(builder, x, weight);
+    let x = reshape(builder, shape!(builder, 1, out_time, out_freq, out_channels), x);
+    let x = transpose(builder, 2, 3, x);
+    transpose(builder, 1, 2, x)
+}
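
conv2d_stride2_square_no_bias implements a 3x3, stride-2 convolution as an
im2col-style matmul: the nine shifted views gathered above become the
in_channels * 9 columns multiplied against the reshaped weight. A naive
single-channel reference for the same computation (illustration only; the
leading pad of 1 and the ceil output size match the pad_4d(1, 2, 1, 2) call
above):

    fn conv3x3_stride2(input: &[Vec<f32>], kernel: &[[f32; 3]; 3]) -> Vec<Vec<f32>> {
        let (h, w) = (input.len(), input[0].len());
        let (out_h, out_w) = (h.div_ceil(2), w.div_ceil(2));
        let mut out = vec![vec![0.0; out_w]; out_h];
        for ot in 0..out_h {
            for of in 0..out_w {
                let mut acc = 0.0;
                for kt in 0..3 {
                    for kf in 0..3 {
                        // Padding of 1 on the leading edges shifts the window
                        // by -1; out-of-range taps read as zero.
                        let t = (2 * ot + kt) as isize - 1;
                        let f = (2 * of + kf) as isize - 1;
                        if t >= 0 && f >= 0 && (t as usize) < h && (f as usize) < w {
                            acc += input[t as usize][f as usize] * kernel[kt][kf];
                        }
                    }
                }
                out[ot][of] = acc;
            }
        }
        out
    }
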
+
+fn pad_4d(
+    builder: &Builder,
+    x: Var,
+    channels: usize,
+    time_steps: usize,
+    freq_bins: usize,
+    top: usize,
+    bottom: usize,
+    left: usize,
+    right: usize,
+) -> Var {
+    let time_pad = zeros(
+        builder,
+        &shape!(builder, 1, channels, top + time_steps + bottom, left),
+        dtype(builder, x.clone()),
+    );
+    let time_pad_right = zeros(
+        builder,
+        &shape!(builder, 1, channels, top + time_steps + bottom, right),
+        dtype(builder, x.clone()),
+    );
+    let top_pad = zeros(
+        builder,
+        &shape!(builder, 1, channels, top, freq_bins),
+        dtype(builder, x.clone()),
+    );
+    let bottom_pad = zeros(
+        builder,
+        &shape!(builder, 1, channels, bottom, freq_bins),
+        dtype(builder, x.clone()),
+    );
+    let x = concat(builder, 2, top_pad, x);
+    let x = concat(builder, 2, x, bottom_pad);
+    let x = concat(builder, 3, time_pad, x);
+    concat(builder, 3, x, time_pad_right)
+}
diff --git a/catgrad-llm/src/models/mod.rs b/catgrad-llm/src/models/mod.rs
index 959bcb32..2a9ef447 100644
--- a/catgrad-llm/src/models/mod.rs
+++ b/catgrad-llm/src/models/mod.rs
@@ -1,3 +1,4 @@
+pub mod conformer;
 pub mod deepseek;
 pub mod gemma3;
 pub mod gemma4;

From 2601a2dc8f3eabc42cae9fdc4dea1579f8cae5fd Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Wed, 29 Apr 2026 15:25:22 +0300
Subject: [PATCH 2/3] gemma4: audio tower

---
 catgrad-llm/src/models/gemma4.rs | 211 ++++++++++++++++++++++++++-----
 catgrad-llm/src/utils/mod.rs     |   5 +
 2 files changed, 184 insertions(+), 32 deletions(-)

diff --git a/catgrad-llm/src/models/gemma4.rs b/catgrad-llm/src/models/gemma4.rs
index 252f938a..d7ec1d84 100644
--- a/catgrad-llm/src/models/gemma4.rs
+++ b/catgrad-llm/src/models/gemma4.rs
@@ -1,17 +1,21 @@
 #![allow(clippy::too_many_arguments)]
 use crate::config::{EosTokenId, LLMConfig};
 use crate::helpers::*;
-use crate::utils::load_and_patchify_dynamic_image;
+use crate::models::conformer::{Gemma4AudioConfig, Gemma4AudioTower};
+use crate::utils::{AUDIO_FEATURE_SIZE, load_and_patchify_dynamic_image, prepare_audio_features};
 use catgrad::prelude::ops::*;
 use catgrad::prelude::*;
 use nn::*;
 use serde::{Deserialize, Serialize};
+use std::path::Path as FsPath;
 
 #[derive(Debug, Clone, Deserialize)]
 struct Gemma4Config {
     text_config: Gemma4TextConfig,
     vision_config: Option<Gemma4VisionConfig>,
+    audio_config: Option<Gemma4AudioConfig>,
     image_token_id: Option<usize>,
+    audio_token_id: Option<usize>,
     vision_soft_tokens_per_image: Option<usize>,
     eos_token_id: Option<EosTokenId>,
 }
@@ -96,6 +100,21 @@ pub struct Gemma4PreparedImageInput {
     pub runtime_vision: Gemma4RuntimeVisionConfig,
 }
 
+#[derive(Debug, Clone, Default, Deserialize, Serialize)]
+pub struct Gemma4RuntimeAudioConfig {
+    pub num_mel_frames: usize,
+    pub num_soft_tokens_per_audio: usize,
+}
+
+#[derive(Debug, Clone)]
+pub struct Gemma4PreparedAudioInput {
+    pub features: Vec<f32>,
+    pub shape: Vec<usize>,
+    pub mask: Vec<f32>,
+    pub mask_shape: Vec<usize>,
+    pub runtime_audio: Gemma4RuntimeAudioConfig,
+}
+
 impl Gemma4TextConfig {
     fn is_sliding_attention_layer(&self, layer_id: usize) -> bool {
         self.layer_types[layer_id] == "sliding_attention"
@@ -222,6 +241,29 @@ pub fn prepare_gemma4_image_input(
     })
 }
 
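The helper added below sizes the soft-token span from the number of valid mel
frames: the tower halves the time axis once per subsampling block and the
result is capped at 750. The arithmetic, worked through (illustration only,
not crate code):

    fn soft_tokens(valid_mel_frames: usize) -> usize {
        // Two stride-2 stages, then the cap.
        valid_mel_frames.div_ceil(2).div_ceil(2).min(750)
    }

    #[test]
    fn soft_token_examples() {
        assert_eq!(soft_tokens(100), 25);
        assert_eq!(soft_tokens(101), 26); // ceil(ceil(101 / 2) / 2) = ceil(51 / 2)
        assert_eq!(soft_tokens(30_000), 750); // long clips saturate at the cap
    }
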
+pub fn prepare_gemma4_audio_input(
+    audio_path: &FsPath,
+    config_json: &serde_json::Value,
+) -> crate::Result<Gemma4PreparedAudioInput> {
+    let config: Gemma4Config = serde_json::from_value(config_json.clone())?;
+    let audio_config = config.audio_config.ok_or_else(|| {
+        crate::LLMError::InvalidModelConfig("gemma4 missing audio_config".to_string())
+    })?;
+    let prepared = prepare_audio_features(audio_path)?;
+    let num_soft_tokens_per_audio =
+        audio_config.num_soft_tokens_for_frames(prepared.valid_mel_frames);
+    Ok(Gemma4PreparedAudioInput {
+        shape: prepared.feature_shape,
+        features: prepared.features,
+        mask_shape: prepared.mask_shape,
+        mask: prepared.mask,
+        runtime_audio: Gemma4RuntimeAudioConfig {
+            num_mel_frames: prepared.num_mel_frames,
+            num_soft_tokens_per_audio,
+        },
+    })
+}
+
 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum Gemma4AttentionKind {
     Sliding,
@@ -241,10 +283,17 @@ struct Gemma4LayerPlan {
 }
 
 #[derive(Debug, Clone)]
-struct Gemma4MultimodalConfig {
-    vision_config: Gemma4VisionConfig,
-    image_token_index: usize,
-    runtime_vision: Gemma4RuntimeVisionConfig,
+enum Gemma4MultimodalConfig {
+    Vision {
+        vision_config: Gemma4VisionConfig,
+        image_token_index: usize,
+        runtime_vision: Gemma4RuntimeVisionConfig,
+    },
+    Audio {
+        audio_config: Gemma4AudioConfig,
+        audio_token_index: usize,
+        runtime_audio: Gemma4RuntimeAudioConfig,
+    },
 }
 
 #[derive(Debug, Clone)]
@@ -315,26 +364,57 @@ impl LLMModel for Gemma4Model {
 
     fn multimodal_metadata(&self) -> Option<MultimodalMetadata> {
         let mm = self.multimodal.as_ref()?;
-        Some(MultimodalMetadata {
-            image_token_index: mm.image_token_index,
-            mm_tokens_per_image: mm.runtime_vision.num_soft_tokens_per_image,
-            hidden_size: self.config.hidden_size,
-            image_size: mm
-                .runtime_vision
-                .patch_grid_height
-                .max(mm.runtime_vision.patch_grid_width)
-                * mm.vision_config.patch_size,
-            patch_size: mm.vision_config.patch_size,
-        })
+        match mm {
+            Gemma4MultimodalConfig::Vision {
+                vision_config,
+                image_token_index,
+                runtime_vision,
+            } => Some(MultimodalMetadata {
+                image_token_index: *image_token_index,
+                mm_tokens_per_image: runtime_vision.num_soft_tokens_per_image,
+                hidden_size: self.config.hidden_size,
+                image_size: runtime_vision
+                    .patch_grid_height
+                    .max(runtime_vision.patch_grid_width)
+                    * vision_config.patch_size,
+                patch_size: vision_config.patch_size,
+            }),
+            Gemma4MultimodalConfig::Audio {
+                audio_token_index,
+                runtime_audio,
+                ..
+            } => Some(MultimodalMetadata {
+                image_token_index: *audio_token_index,
+                mm_tokens_per_image: runtime_audio.num_soft_tokens_per_audio,
+                hidden_size: self.config.hidden_size,
+                image_size: 0,
+                patch_size: 0,
+            }),
+        }
     }
 
     fn multimodal_vision_module(&self) -> Option<Box<dyn DynModule>> {
         let mm = self.multimodal.as_ref()?;
-        Some(Box::new(Gemma4VisionEmbeddings {
-            vision_config: mm.vision_config.clone(),
-            runtime_vision: mm.runtime_vision.clone(),
-            text_hidden_size: self.config.hidden_size,
-        }))
+        match mm {
+            Gemma4MultimodalConfig::Vision {
+                vision_config,
+                runtime_vision,
+                ..
+            } => Some(Box::new(Gemma4VisionEmbeddings {
+                vision_config: vision_config.clone(),
+                runtime_vision: runtime_vision.clone(),
+                text_hidden_size: self.config.hidden_size,
+            })),
+            Gemma4MultimodalConfig::Audio {
+                audio_config,
+                runtime_audio,
+                ..
+            } => Some(Box::new(Gemma4AudioEmbeddings {
+                audio_config: audio_config.clone(),
+                runtime_audio: runtime_audio.clone(),
+                text_hidden_size: self.config.hidden_size,
+            })),
+        }
     }
 
     fn multimodal_language_module(&self) -> Option<Box<dyn DynModule>> {
@@ -346,15 +426,26 @@ impl LLMModel for Gemma4Model {
 
     fn multimodal_interpolate_prompt(&self, prompt: &str) -> Option<String> {
         let mm = self.multimodal.as_ref()?;
-        Some(prompt.replace(
-            "<|image|>",
-            &format!(
-                "{}{}{}",
-                "<|image>",
-                "<|image|>".repeat(mm.runtime_vision.num_soft_tokens_per_image),
-                ""
-            ),
-        ))
+        match mm {
+            Gemma4MultimodalConfig::Vision { runtime_vision, .. } => Some(prompt.replace(
+                "<|image|>",
+                &format!(
+                    "{}{}{}",
+                    "<|image>",
+                    "<|image|>".repeat(runtime_vision.num_soft_tokens_per_image),
+                    ""
+                ),
+            )),
+            Gemma4MultimodalConfig::Audio { runtime_audio, .. } => Some(prompt.replace(
+                "<|audio|>",
+                &format!(
+                    "{}{}{}",
+                    "<|audio>",
+                    "<|audio|>".repeat(runtime_audio.num_soft_tokens_per_audio),
+                    ""
+                ),
+            )),
+        }
     }
 }
 
@@ -363,12 +454,15 @@ impl Gemma4Model {
         root: &str,
         config_json: &serde_json::Value,
         runtime_vision: Option<&Gemma4RuntimeVisionConfig>,
+        runtime_audio: Option<&Gemma4RuntimeAudioConfig>,
         dtype: Dtype,
     ) -> crate::Result<Self> {
         let Gemma4Config {
             mut text_config,
             vision_config,
+            audio_config,
             image_token_id,
+            audio_token_id,
             eos_token_id,
             ..
         }: Gemma4Config = serde_json::from_value(config_json.clone())?;
@@ -379,13 +473,22 @@
         let multimodal = match (vision_config, image_token_id, runtime_vision) {
             (Some(vision_config), Some(image_token_index), Some(runtime_vision)) => {
-                Some(Gemma4MultimodalConfig {
+                Some(Gemma4MultimodalConfig::Vision {
                     vision_config,
                     image_token_index,
                     runtime_vision: runtime_vision.clone(),
                 })
             }
-            _ => None,
+            _ => match (audio_config, audio_token_id, runtime_audio) {
+                (Some(audio_config), Some(audio_token_index), Some(runtime_audio)) => {
+                    Some(Gemma4MultimodalConfig::Audio {
+                        audio_config,
+                        audio_token_index,
+                        runtime_audio: runtime_audio.clone(),
+                    })
+                }
+                _ => None,
+            },
         };
 
         let first_shared_layer = text_config.num_hidden_layers - text_config.num_kv_shared_layers;
@@ -1143,6 +1246,50 @@ fn gemma4_apply_2d_rope(
     concat(builder, 3, x_part, y_part)
 }
 
+pub struct Gemma4AudioEmbeddings {
+    audio_config: Gemma4AudioConfig,
+    runtime_audio: Gemma4RuntimeAudioConfig,
+    text_hidden_size: usize,
+}
+
+impl DynModule for Gemma4AudioEmbeddings {
+    fn path(&self) -> Path {
+        path(vec!["Gemma4AudioEmbeddings"]).unwrap()
+    }
+
+    fn ty(&self) -> (Vec<Type>, Vec<Type>) {
+        use catgrad::typecheck::TypeExpr;
+        let t = Type::Tensor(TypeExpr::Var(0));
+        (vec![t.clone(), t.clone()], vec![t])
+    }
+
+    fn def(&self, builder: &Builder, args: Vec<Var>) -> Vec<Var> {
+        let [features, mask]: [Var; 2] = args.try_into().expect("expected 2 inputs");
+        let tower = Gemma4AudioTower {
+            config: self.audio_config.clone(),
+            input_time_steps: self.runtime_audio.num_mel_frames,
+            input_feature_bins: AUDIO_FEATURE_SIZE,
+        };
+        let x = tower.audio_model(builder, features, mask);
+        let x = slice(
+            builder,
+            1,
+            0,
+            self.runtime_audio.num_soft_tokens_per_audio,
+            x,
+        );
+        let x = rmsnorm_raw::<3>(builder, self.audio_config.rms_norm_eps, x);
+        let x = linear_no_bias(
+            builder,
+            self.audio_config.output_proj_dims,
+            self.text_hidden_size,
+            path(vec!["model", "embed_audio", "embedding_projection"]).unwrap(),
+            x,
+        );
+        vec![x]
+    }
+}
+
 pub struct Gemma4VisionEmbeddings {
     vision_config: Gemma4VisionConfig,
     runtime_vision: Gemma4RuntimeVisionConfig,
diff --git a/catgrad-llm/src/utils/mod.rs b/catgrad-llm/src/utils/mod.rs
index 9641b06d..eb0f2e99 100644
--- a/catgrad-llm/src/utils/mod.rs
+++ b/catgrad-llm/src/utils/mod.rs
@@ -145,6 +145,7 @@ pub struct PreparedImageInput {
 pub enum ModelRuntimeContext {
     Qwen3_5Vision(models::qwen3_5::Qwen3_5RuntimeVisionConfig),
     Gemma4Vision(models::gemma4::Gemma4RuntimeVisionConfig),
+    Gemma4Audio(models::gemma4::Gemma4RuntimeAudioConfig),
     Lfm2Vision(models::lfm2::Lfm2RuntimeVisionConfig),
 }
 
@@ -321,6 +322,10 @@ pub fn get_model(
             Some(ModelRuntimeContext::Gemma4Vision(runtime_vision)) => Some(runtime_vision),
             _ => None,
         },
+        match runtime_context {
+            Some(ModelRuntimeContext::Gemma4Audio(runtime_audio)) => Some(runtime_audio),
+            _ => None,
+        },
         dtype,
     )?),
     "Mistral3ForConditionalGeneration" => Box::new(models::mistral3::Mistral3Model::new(
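
A shape-level summary of what patch 2 wires up (a sketch under assumed names,
not crate API): Gemma4AudioEmbeddings feeds [1, num_mel_frames,
AUDIO_FEATURE_SIZE] features plus a [1, num_mel_frames] mask through the
tower, keeps the first num_soft_tokens_per_audio rows, and projects them to
the text width:

    fn audio_embedding_shapes(
        num_mel_frames: usize,
        num_soft_tokens_per_audio: usize,
        output_proj_dims: usize,
        text_hidden_size: usize,
    ) -> ([usize; 3], [usize; 3]) {
        // The tower halves the time axis twice (the two sscp blocks).
        let tower_tokens = num_mel_frames.div_ceil(2).div_ceil(2);
        // Tower output, then the sliced and projected soft-token embeddings.
        (
            [1, tower_tokens, output_proj_dims],
            [1, num_soft_tokens_per_audio, text_hidden_size],
        )
    }
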
From 3e55ae6a32abeff8a25854a14dcfce99caa1924c Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Wed, 29 Apr 2026 16:52:25 +0300
Subject: [PATCH 3/3] example app: Add audio input processing

---
 catgrad-llm/examples/llama/main.rs | 250 ++++++++++++++++++-----------
 catgrad-llm/src/run.rs             |   2 +-
 catgrad-llm/src/utils/mod.rs       |  20 +--
 3 files changed, 171 insertions(+), 101 deletions(-)

diff --git a/catgrad-llm/examples/llama/main.rs b/catgrad-llm/examples/llama/main.rs
index 029f8701..8392f903 100644
--- a/catgrad-llm/examples/llama/main.rs
+++ b/catgrad-llm/examples/llama/main.rs
@@ -3,6 +3,7 @@ use catgrad::interpreter::backend::candle::CandleBackend;
 use catgrad::interpreter::backend::ndarray::NdArrayBackend;
 use catgrad::prelude::*;
 use catgrad_llm::helpers::{LLMModel, ToolCall, ToolUseStep};
+use catgrad_llm::models;
 use catgrad_llm::utils::*;
 use clap::{Parser, ValueEnum};
 use minijinja::{Value, context};
@@ -38,6 +39,9 @@ struct Args {
     /// Optional image input for multimodal-capable models
     #[arg(short = 'i', long)]
     image: Option<std::path::PathBuf>,
+    /// Optional audio input for Gemma4 audio-capable models
+    #[arg(short = 'a', long)]
+    audio: Option<std::path::PathBuf>,
     /// Pass raw prompt without chat template
     #[arg(long)]
     raw: bool,
@@ -164,6 +168,12 @@ fn main() -> Result<()> {
     if args.tool_use && args.bench.is_some() {
         anyhow::bail!("--tool-use does not support --bench");
     }
+    if args.image.is_some() && args.audio.is_some() {
+        anyhow::bail!("--image and --audio are mutually exclusive");
+    }
+    if args.tool_use && (args.image.is_some() || args.audio.is_some()) {
+        anyhow::bail!("--tool-use does not support multimodal inputs");
+    }
 
     let app_config = get_app_config(&args)?;
     if args.list_models {
@@ -198,15 +208,29 @@ fn get_models(app_config: &AppConfig) -> Vec<(&str, &str)> {
     models
 }
 
-fn user_message(prompt: &str, has_image: bool) -> Value {
-    if has_image {
-        let content = vec![
-            context!(type => "text", text => prompt),
-            context!(type => "image"),
-        ];
-        context!(role => "user", content => content)
-    } else {
-        context!(role => "user", content => prompt)
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum PromptModality {
+    Image,
+    Audio,
+}
+
+fn user_message(prompt: &str, modality: Option<PromptModality>) -> Value {
+    match modality {
+        Some(PromptModality::Image) => {
+            let content = vec![
+                context!(type => "text", text => prompt),
+                context!(type => "image"),
+            ];
+            context!(role => "user", content => content)
+        }
+        Some(PromptModality::Audio) => {
+            let content = vec![
+                context!(type => "text", text => prompt),
+                context!(type => "audio"),
+            ];
+            context!(role => "user", content => content)
+        }
+        None => context!(role => "user", content => prompt),
     }
 }
 
@@ -255,7 +279,7 @@ fn render_tool_prompt(
     Ok(render_chat_template_values(
         chat_template,
         tokenizer_config,
-        &[user_message(prompt, false)],
+        &[user_message(prompt, None)],
        RenderChatTemplateOptions {
            enable_thinking,
            tools: Some(tools),
@@ -284,7 +308,7 @@ fn render_tool_follow_up_prompt(
         )
     };
 
-    let mut messages = vec![user_message(prompt, false), assistant_message];
+    let mut messages = vec![user_message(prompt, None), assistant_message];
     messages.extend(
         tool_use_step
             .tool_calls
@@ -347,7 +371,15 @@ fn run_with_backend(
     let mut pp = 0;
     let mut tg = 0;
     let mut max_seq_len = args.max_seq_len;
-    let use_image = args.image.is_some() && !benchmarking;
+    let modality = if benchmarking {
+        None
+    } else if args.image.is_some() {
+        Some(PromptModality::Image)
+    } else if args.audio.is_some() {
+        Some(PromptModality::Audio)
+    } else {
+        None
+    };
 
     let prompt = if let Some(bench) = &args.bench {
         pp = bench[0];
     } else if chat_template.is_empty() || args.raw {
         args.prompt.clone()
     } else {
-        render_chat_template(
+        render_chat_template_values(
             &chat_template,
             &tokenizer_config,
-            &args.prompt,
-            use_image,
-            args.thinking,
+            &[user_message(&args.prompt, modality)],
+            RenderChatTemplateOptions {
+                enable_thinking: args.thinking,
+                tools: None,
+            },
         )?
     };
 
-    let prepared_multimodal = if use_image {
-        prepare_multimodal_input(&config_json, args.image.as_deref())?
-    } else {
-        Default::default()
+    let (prepared_image, prepared_audio, runtime_context) = match modality {
+        Some(PromptModality::Image) => {
+            let prepared = prepare_multimodal_input(&config_json, args.image.as_deref())?;
+            (prepared.image, None, prepared.runtime_context)
+        }
+        Some(PromptModality::Audio) => {
+            let audio_path = args
+                .audio
+                .as_ref()
+                .expect("audio existence already checked");
+            let prepared = models::gemma4::prepare_gemma4_audio_input(audio_path, &config_json)?;
+            (
+                None,
+                Some(prepared.clone()),
+                Some(ModelRuntimeContext::Gemma4Audio(prepared.runtime_audio)),
+            )
+        }
+        None => (None, None, None),
     };
-    let runtime_context = prepared_multimodal.runtime_context.as_ref();
 
-    if !benchmarking && !use_image && !args.tool_use {
+    if !benchmarking && modality.is_none() && !args.tool_use {
         print!("{prompt}");
     }
 
-    let prompt = if use_image {
-        interpolate_multimodal_prompt(&config_json, runtime_context, &prompt)?
+    let prompt = if modality.is_some() {
+        interpolate_multimodal_prompt(&config_json, runtime_context.as_ref(), &prompt)?
     } else {
         prompt
     };
@@ -423,7 +470,7 @@
     let model = get_model(
         &config_json,
         max_sequence_length,
-        runtime_context,
+        runtime_context.as_ref(),
         model_dtype,
     )?;
     post_process_model_weights(
         &mut parameter_types,
     )?;
 
-    let mm_metadata = if use_image {
+    let mm_metadata = if modality.is_some() {
         Some(
             model
                 .multimodal_metadata()
@@ -446,7 +493,7 @@
     let typed_term = if let Some(load_path) = &args.load {
         let file = std::fs::File::open(load_path)?;
         serde_json::from_reader(file)?
-    } else if use_image {
+    } else if modality.is_some() {
         let language_model = model.multimodal_language_module().ok_or_else(|| {
             anyhow::anyhow!(
                 "Model {} does not provide multimodal language module",
                 model_name
             )
         })?;
@@ -473,7 +520,7 @@
     // Get stdlib environment and extend with parameter declarations
     let mut env = stdlib();
 
-    let load_prefix = if use_image {
+    let load_prefix = if modality.is_some() {
         catgrad::prelude::Path::empty()
     } else {
         model.path()
     };
@@ -490,66 +537,89 @@
     let mut multimodal_ctx: Option<MultimodalRuntime<B>> = None;
     if let Some(mm) = mm_metadata {
-        let vision_model = model.multimodal_vision_module().ok_or_else(|| {
-            anyhow::anyhow!("Model {} does not provide vision module", model_name)
-        })?;
-        let image_path = args
-            .image
-            .as_ref()
-            .expect("image existence already checked");
-        let prepared_image = prepared_multimodal.image.as_ref().ok_or_else(|| {
+        let encoder_model = model.multimodal_vision_module().ok_or_else(|| {
             anyhow::anyhow!(
-                "Model {} did not provide prepared image input for {}",
-                model_name,
-                image_path.display()
+                "Model {} does not provide multimodal encoder module",
+                model_name
             )
         })?;
-        let image_data = prepared_image.data.clone();
-        let image_shape = prepared_image.shape.clone();
-        let cache_path =
-            cache_path_for_embeddings(&model_name, &image_path.to_string_lossy(), &image_data);
-        let visual_embeddings = if let Ok(cached) = load_cached_embeddings(&cache_path) {
-            eprintln!(
-                "Loading cached image features from: {}",
-                cache_path.display()
-            );
-            interpreter::float_tensor(
+        let modality_embeddings = if let Some(prepared_image) = prepared_image.as_ref() {
+            let image_path = args
+                .image
+                .as_ref()
+                .expect("image existence already checked");
+            let image_data = prepared_image.data.clone();
+            let image_shape = prepared_image.shape.clone();
+            let cache_path =
+                cache_path_for_embeddings(&model_name, &image_path.to_string_lossy(), &image_data);
+            if let Ok(cached) = load_cached_embeddings(&cache_path) {
+                eprintln!(
+                    "Loading cached image features from: {}",
+                    cache_path.display()
+                );
+                interpreter::float_tensor(
+                    &interpreter.backend,
+                    Shape(vec![1, mm.mm_tokens_per_image, mm.hidden_size]),
+                    cached,
+                    model_dtype,
+                )
+                .map_err(|e| anyhow::anyhow!("BackendError: {:?}", e))?
+            } else {
+                let image_tensor = interpreter::float_tensor(
+                    &interpreter.backend,
+                    Shape(image_shape),
+                    image_data,
+                    model_dtype,
+                )
+                .map_err(|e| anyhow::anyhow!("BackendError: {:?}", e))?;
+                let encoder_term = encoder_model
+                    .term()
+                    .ok_or_else(|| anyhow::anyhow!("failed to build multimodal encoder term"))?;
+                let results = interpreter.run(encoder_term.term, vec![image_tensor])?;
+                let embeddings = results
+                    .first()
+                    .cloned()
+                    .ok_or_else(|| anyhow::anyhow!("multimodal encoder returned no outputs"))?;
+                let flattened = to_f32_vec(&interpreter.backend, &embeddings)?;
+                save_cached_embeddings(&cache_path, &flattened)?;
+                eprintln!("Saved image features to: {}", cache_path.display());
+                embeddings
+            }
+        } else {
+            let prepared_audio = prepared_audio.as_ref().ok_or_else(|| {
+                anyhow::anyhow!("Model {} did not provide prepared audio input", model_name)
+            })?;
+            let audio_tensor = interpreter::float_tensor(
                 &interpreter.backend,
-                Shape(vec![1, mm.mm_tokens_per_image, mm.hidden_size]),
-                cached,
+                Shape(prepared_audio.shape.clone()),
+                prepared_audio.features.clone(),
                 model_dtype,
             )
-            .map_err(|e| anyhow::anyhow!("BackendError: {:?}", e))?
-        } else {
-            let image_tensor = interpreter::float_tensor(
+            .map_err(|e| anyhow::anyhow!("BackendError: {:?}", e))?;
+            let mask_tensor = interpreter::float_tensor(
                 &interpreter.backend,
-                Shape(image_shape),
-                image_data,
+                Shape(prepared_audio.mask_shape.clone()),
+                prepared_audio.mask.clone(),
                 model_dtype,
             )
             .map_err(|e| anyhow::anyhow!("BackendError: {:?}", e))?;
-            let vision_term = vision_model
+            let encoder_term = encoder_model
                 .term()
-                .ok_or_else(|| anyhow::anyhow!("failed to build vision model term"))?;
-            let results = interpreter.run(vision_term.term, vec![image_tensor])?;
-            let visual_embeddings = results
+                .ok_or_else(|| anyhow::anyhow!("failed to build multimodal encoder term"))?;
+            let results = interpreter.run(encoder_term.term, vec![audio_tensor, mask_tensor])?;
+            results
                 .first()
                 .cloned()
-                .ok_or_else(|| anyhow::anyhow!("Vision model returned no outputs"))?;
-            let flattened = to_f32_vec(&interpreter.backend, &visual_embeddings)?;
-            save_cached_embeddings(&cache_path, &flattened)?;
-            eprintln!("Saved image features to: {}", cache_path.display());
-            visual_embeddings
+                .ok_or_else(|| anyhow::anyhow!("multimodal encoder returned no outputs"))?
         };
 
         multimodal_ctx = Some(MultimodalRuntime {
             hidden_size: mm.hidden_size,
-            image_token_index: mm.image_token_index,
-            visual_embeddings,
+            placeholder_token_index: mm.image_token_index,
+            modality_embeddings,
         });
     }
 
-    let use_kv_cache = args.kv_cache || use_image;
+    let use_kv_cache = args.kv_cache || modality.is_some();
 
     if args.tool_use {
         let (first_text, first_tokens, first_elapsed_pp, first_elapsed_gen) = generate_stream(
             model.as_ref(),
@@ -690,8 +760,8 @@ fn to_f32_vec(
 
 struct MultimodalRuntime<B: interpreter::Backend> {
     hidden_size: usize,
-    image_token_index: usize,
-    visual_embeddings: interpreter::Value<B>,
+    placeholder_token_index: usize,
+    modality_embeddings: interpreter::Value<B>,
 }
 
 enum DecodeInputs<'a, B: interpreter::Backend> {
@@ -701,9 +771,9 @@
     Multimodal {
         input_tokens: &'a [u32],
         hidden_size: usize,
-        image_token_index: usize,
-        visual_embeddings: &'a interpreter::Value<B>,
-        use_image_embeddings: bool,
+        placeholder_token_index: usize,
+        modality_embeddings: &'a interpreter::Value<B>,
+        use_modality_embeddings: bool,
     },
 }
 
@@ -726,7 +796,7 @@ fn generate_stream(
 ) -> Result<(String, usize, std::time::Duration, std::time::Duration)> {
     let eos_token_ids = model.config().get_eos_token_ids();
     let mut state_cache = empty_state_cache(&interpreter.backend, model)?;
-    let mut use_image_embeddings = config.multimodal_ctx.is_some();
+    let mut use_modality_embeddings = config.multimodal_ctx.is_some();
     let mut output = String::new();
     let mut generated_tokens = 0;
     let mut start_gen = std::time::Instant::now();
@@ -737,9 +807,9 @@
             DecodeInputs::Multimodal {
                 input_tokens: &token_ids,
                 hidden_size: ctx.hidden_size,
-                image_token_index: ctx.image_token_index,
-                visual_embeddings: &ctx.visual_embeddings,
-                use_image_embeddings,
+                placeholder_token_index: ctx.placeholder_token_index,
+                modality_embeddings: &ctx.modality_embeddings,
+                use_modality_embeddings,
             }
         } else {
             DecodeInputs::Text {
@@ -769,7 +839,7 @@
             token_ids.push(next_token_id);
         }
         if config.multimodal_ctx.is_some() && config.use_kv_cache {
-            use_image_embeddings = false;
+            use_modality_embeddings = false;
         }
         let decoded_token = tokenizer.decode(&[next_token_id], true).unwrap();
         output.push_str(&decoded_token);
@@ -814,35 +884,35 @@ fn run_interpreter(
         DecodeInputs::Multimodal {
             input_tokens,
             hidden_size,
-            image_token_index,
-            visual_embeddings,
-            use_image_embeddings,
+            placeholder_token_index,
+            modality_embeddings,
+            use_modality_embeddings,
         } => {
             input_seq_len = input_tokens.len();
-            let empty_image_embeddings = interpreter::float_tensor(
+            let empty_modality_embeddings = interpreter::float_tensor(
                 &interpreter.backend,
                 Shape(vec![1, 0, hidden_size]),
                 Vec::<f32>::new(),
                 model.dtype(),
             )
-            .map_err(|err| anyhow::anyhow!("empty image tensor error: {:?}", err))?;
+            .map_err(|err| anyhow::anyhow!("empty modality tensor error: {:?}", err))?;
 
-            let (text_before_tokens, text_after_tokens) = if use_image_embeddings {
-                split_image_tokens(input_tokens, image_token_index)?
+            let (text_before_tokens, text_after_tokens) = if use_modality_embeddings {
+                split_placeholder_tokens(input_tokens, placeholder_token_index)?
             } else {
                 (&[][..], input_tokens)
             };
 
             let text_before = token_tensor(interpreter, "text_before", text_before_tokens)?;
             let text_after = token_tensor(interpreter, "text_after", text_after_tokens)?;
-            let image_embeddings = if use_image_embeddings {
-                visual_embeddings.clone()
+            let modality_embeddings = if use_modality_embeddings {
+                modality_embeddings.clone()
             } else {
-                empty_image_embeddings
+                empty_modality_embeddings
             };
 
             inputs.push(text_before);
-            inputs.push(image_embeddings);
+            inputs.push(modality_embeddings);
            inputs.push(text_after);
        }
    }
diff --git a/catgrad-llm/src/run.rs b/catgrad-llm/src/run.rs
index 182255da..75907563 100644
--- a/catgrad-llm/src/run.rs
+++ b/catgrad-llm/src/run.rs
@@ -410,7 +410,7 @@ impl ModelRunner {
         let mut inputs = Vec::with_capacity(self.state_cache.len() + 5);
         if let Some(multimodal) = self.multimodal.as_mut() {
             let (text_before_tokens, text_after_tokens) = if multimodal.use_image_embeddings {
-                split_image_tokens(tokens, multimodal.image_token_index)?
+                split_placeholder_tokens(tokens, multimodal.image_token_index)?
             } else {
                 (&[][..], tokens)
             };
diff --git a/catgrad-llm/src/utils/mod.rs b/catgrad-llm/src/utils/mod.rs
index eb0f2e99..0cff93b3 100644
--- a/catgrad-llm/src/utils/mod.rs
+++ b/catgrad-llm/src/utils/mod.rs
@@ -274,28 +274,28 @@ pub fn interpolate_multimodal_prompt(
     )))
 }
 
-/// Split a multimodal token sequence into the text before and after the image-token span.
-pub fn split_image_tokens(
+/// Split a multimodal token sequence into the text before and after the placeholder-token span.
+pub fn split_placeholder_tokens(
     input_tokens: &[u32],
-    image_token_index: usize,
+    mm_token_index: usize,
 ) -> Result<(&[u32], &[u32])> {
-    let image_token = image_token_index as u32;
+    let mm_token = mm_token_index as u32;
     let first_image_token_index = input_tokens
         .iter()
-        .position(|&token| token == image_token)
+        .position(|&token| token == mm_token)
         .ok_or_else(|| {
             LLMError::InvalidModelConfig(format!(
-                "multimodal prompt is missing image token {image_token_index}"
+                "multimodal prompt is missing image or audio token {mm_token_index}"
             ))
         })?;
-    let last_image_token_index = input_tokens
+    let last_mm_token_index = input_tokens
         .iter()
-        .rposition(|&token| token == image_token)
-        .expect("first image token implies last image token");
+        .rposition(|&token| token == mm_token)
+        .expect("first placeholder token implies last placeholder token");
 
     Ok((
         &input_tokens[..first_image_token_index],
-        &input_tokens[last_image_token_index + 1..],
+        &input_tokens[last_mm_token_index + 1..],
     ))
 }
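
A usage sketch for the renamed helper (illustration only; the placeholder id 7
is arbitrary). Note that the split keeps only the text outside the first and
last placeholder occurrence, so the placeholders are assumed to form one
contiguous run:

    #[test]
    fn split_around_placeholders() {
        let tokens = [1, 2, 7, 7, 7, 3, 4];
        let (before, after) = split_placeholder_tokens(&tokens, 7).unwrap();
        assert_eq!(before, &[1, 2]);
        assert_eq!(after, &[3, 4]);
    }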