diff --git a/crates/vad/src/continuous.rs b/crates/vad/src/continuous.rs index c0240ba65..386e5434e 100644 --- a/crates/vad/src/continuous.rs +++ b/crates/vad/src/continuous.rs @@ -7,7 +7,7 @@ use std::{ use futures_util::{future, Stream, StreamExt}; use kalosm_sound::AsyncSource; -use silero_rs::{VadConfig, VadSession, VadTransition}; +pub use silero_rs::{VadConfig, VadSession, VadTransition}; #[derive(Debug, Clone)] pub enum VadStreamItem { diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index c711ef6b6..d116b9872 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -68,6 +68,8 @@ struct AudioChannels { process_mic_rx: flume::Receiver>, process_speaker_tx: flume::Sender>, process_speaker_rx: flume::Receiver>, + control_tx: flume::Sender, + control_rx: flume::Receiver, } impl AudioChannels { @@ -81,6 +83,8 @@ impl AudioChannels { let (process_speaker_tx, process_speaker_rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); + let (control_tx, control_rx) = flume::bounded::(8); + let (save_mic_raw_tx, save_mic_raw_rx) = if cfg!(debug_assertions) { let (tx, rx) = flume::bounded::>(CHUNK_BUFFER_SIZE); (Some(tx), Some(rx)) @@ -110,6 +114,8 @@ impl AudioChannels { process_mic_rx, process_speaker_tx, process_speaker_rx, + control_tx, + control_rx, } } @@ -293,6 +299,7 @@ impl Session { let save_speaker_raw_tx = channels.save_speaker_raw_tx.clone(); let process_mic_tx = channels.process_mic_tx.clone(); let process_speaker_tx = channels.process_speaker_tx.clone(); + let control_tx = channels.control_tx.clone(); async move { let mut aec = hypr_aec::AEC::new().unwrap(); @@ -300,6 +307,13 @@ impl Session { let mut speaker_agc = hypr_agc::Agc::default(); let mut last_broadcast = Instant::now(); + let mut vad = hypr_vad::VadSession::new(hypr_vad::VadConfig { + redemption_time: Duration::from_millis(70), + pre_speech_pad: Duration::from_millis(70), + ..Default::default() + }) + .unwrap(); + loop { let (mut mic_chunk_raw, mut speaker_chunk): (Vec, Vec) = match tokio::join!(mic_rx.recv_async(), speaker_rx.recv_async()) { @@ -325,6 +339,21 @@ impl Session { let _ = rx.changed().await; continue; } + let mixed: Vec = mic_chunk + .iter() + .zip(speaker_chunk.iter()) + .map(|(mic, speaker)| (mic + speaker).clamp(-1.0, 1.0)) + .collect(); + + if let Ok(transitions) = vad.process(&mixed) { + for transition in transitions { + if let hypr_vad::VadTransition::SpeechEnd { .. } = transition { + let _ = control_tx + .send_async(owhisper_interface::ControlMessage::Finalize) + .await; + } + } + } let processed_mic = mic_chunk.clone(); let processed_speaker = speaker_chunk.clone(); @@ -442,12 +471,18 @@ impl Session { .into_stream() .map(|v| hypr_audio_utils::f32_to_i16_bytes(v.into_iter())); - let combined_audio_stream = - mic_audio_stream - .zip(speaker_audio_stream) - .map(|(mic, speaker)| { - owhisper_interface::MixedMessage::Audio((mic.into(), speaker.into())) - }); + let audio_stream = mic_audio_stream + .zip(speaker_audio_stream) + .map(|(mic, speaker)| { + owhisper_interface::MixedMessage::Audio((mic.into(), speaker.into())) + }); + + let control_stream = channels + .control_rx + .into_stream() + .map(|control| owhisper_interface::MixedMessage::Control(control)); + + let combined_audio_stream = futures_util::stream::select(audio_stream, control_stream); tasks.spawn({ let app = self.app.clone(); diff --git a/plugins/listener/src/manager.rs b/plugins/listener/src/manager.rs index 53ac4c293..f29d8130a 100644 --- a/plugins/listener/src/manager.rs +++ b/plugins/listener/src/manager.rs @@ -2,6 +2,7 @@ pub struct TranscriptManager { id: uuid::Uuid, partial_words: Vec, + final_words: Vec, } impl Default for TranscriptManager { @@ -9,6 +10,7 @@ impl Default for TranscriptManager { Self { id: uuid::Uuid::new_v4(), partial_words: Vec::new(), + final_words: Vec::new(), } } } @@ -93,16 +95,35 @@ impl TranscriptManager { }; if is_final { - let last_final_word_end = words.last().unwrap().end; - self.partial_words = self - .partial_words - .iter() - .filter(|w| w.start > last_final_word_end) - .cloned() - .collect::>(); + // TODO: maybe we should replace. + let new_final_words: Vec = words + .into_iter() + .filter(|new_word| { + !self.final_words.iter().any(|existing_word| { + (existing_word.start - new_word.start).abs() < 0.001 + && (existing_word.end - new_word.end).abs() < 0.001 + }) + }) + .collect(); + + self.final_words.extend(new_final_words.clone()); + self.final_words.sort_by(|a, b| { + a.start + .partial_cmp(&b.start) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + if let Some(last_final_word) = new_final_words.last() { + self.partial_words = self + .partial_words + .iter() + .filter(|w| w.start > last_final_word.end) + .cloned() + .collect::>(); + } return Diff { - final_words: words.clone(), + final_words: new_final_words, partial_words: self.partial_words.clone(), }; } else if data.confidence > 0.6 {