diff --git a/.github/workflows/desktop_ci.yaml b/.github/workflows/desktop_ci.yaml index 0706efefcf..61aa5c75cd 100644 --- a/.github/workflows/desktop_ci.yaml +++ b/.github/workflows/desktop_ci.yaml @@ -143,7 +143,9 @@ jobs: --exclude tauri-plugin-webhook \ --exclude tauri-plugin-windows \ --exclude db3 \ - --exclude db-core2 + --exclude db-core2 \ + --exclude activity-capture-macos \ + --exclude tauri-plugin-activity-capture ci: if: always() diff --git a/apps/cli/src/app.rs b/apps/cli/src/app.rs index aee89b67de..fe7a5c73af 100644 --- a/apps/cli/src/app.rs +++ b/apps/cli/src/app.rs @@ -96,10 +96,10 @@ impl AppContext { fn analytics_client() -> hypr_analytics::AnalyticsClient { let mut builder = hypr_analytics::AnalyticsClientBuilder::default(); - if std::env::var_os("DO_NOT_TRACK").is_none() { - if let Some(key) = option_env!("POSTHOG_API_KEY") { - builder = builder.with_posthog(key); - } + if std::env::var_os("DO_NOT_TRACK").is_none() + && let Some(key) = option_env!("POSTHOG_API_KEY") + { + builder = builder.with_posthog(key); } builder.build() } diff --git a/apps/cli/src/cli.rs b/apps/cli/src/cli.rs index 2e12bc95e7..9c3906b04b 100644 --- a/apps/cli/src/cli.rs +++ b/apps/cli/src/cli.rs @@ -7,7 +7,7 @@ use clap_verbosity_flag::{InfoLevel, Verbosity}; name = "char", version, propagate_version = true, - after_help = "Docs: https://char.com/docs/cli\nBugs: https://github.com/fastrepl/char/issues" + after_help = "Docs: https://cli.char.com\nDiscussions: https://github.com/fastrepl/char/discussions/4788\nBugs: https://github.com/fastrepl/char/issues" )] pub struct Cli { #[command(subcommand)] diff --git a/apps/cli/src/commands/export/helpers.rs b/apps/cli/src/commands/export/helpers.rs index 0b55f2d6c3..6ca1bfb945 100644 --- a/apps/cli/src/commands/export/helpers.rs +++ b/apps/cli/src/commands/export/helpers.rs @@ -56,13 +56,13 @@ pub(super) fn build_segments( }) .unwrap_or_else(|| format!("Channel {}", word.channel)); - if let Some(last) = segments.last_mut() { - if 
last.speaker == speaker { - last.text.push(' '); - last.text.push_str(&word.text); - last.end_ms = word.end_ms; - continue; - } + if let Some(last) = segments.last_mut() + && last.speaker == speaker + { + last.text.push(' '); + last.text.push_str(&word.text); + last.end_ms = word.end_ms; + continue; } segments.push(Segment { diff --git a/apps/cli/src/commands/model/list.rs b/apps/cli/src/commands/model/list.rs index fe483a1def..877847d4a9 100644 --- a/apps/cli/src/commands/model/list.rs +++ b/apps/cli/src/commands/model/list.rs @@ -53,7 +53,6 @@ pub(crate) async fn collect_model_rows( rows.sort_by(|a, b| { status_rank(&a.status) .cmp(&status_rank(&b.status)) - .then_with(|| a.kind.cmp(&b.kind)) .then_with(|| a.name.cmp(&b.name)) }); rows @@ -81,7 +80,7 @@ pub(super) async fn write_model_output( .load_preset(UTF8_FULL_CONDENSED) .set_content_arrangement(ContentArrangement::Dynamic); - table.set_header(vec!["Name", "Type", "Status", "Title", "Details", "Path"]); + table.set_header(vec!["Name", "Status", "Path"]); for row in rows { let path = match &home { @@ -92,14 +91,7 @@ pub(super) async fn write_model_output( .unwrap_or_else(|| row.install_path.clone()), None => row.install_path.clone(), }; - table.add_row(vec![ - row.name.clone(), - row.kind.clone(), - row.status.clone(), - row.display_name.clone(), - detail_text(row).to_string(), - path, - ]); + table.add_row(vec![row.name.clone(), row.status.clone(), path]); } println!("{table}"); @@ -108,14 +100,6 @@ pub(super) async fn write_model_output( Ok(()) } -fn detail_text(row: &ModelRow) -> &str { - if row.description.is_empty() { - "-" - } else { - row.description.as_str() - } -} - fn status_rank(status: &str) -> usize { match status { "downloaded" => 0, @@ -154,7 +138,6 @@ mod tests { rows.sort_by(|a, b| { status_rank(&a.status) .cmp(&status_rank(&b.status)) - .then_with(|| a.kind.cmp(&b.kind)) .then_with(|| a.name.cmp(&b.name)) }); diff --git a/apps/cli/src/commands/model/mod.rs 
b/apps/cli/src/commands/model/mod.rs index d879fe4f7a..4ca9d7c072 100644 --- a/apps/cli/src/commands/model/mod.rs +++ b/apps/cli/src/commands/model/mod.rs @@ -125,6 +125,9 @@ fn make_manager( } pub(crate) fn model_is_enabled(model: &LocalModel) -> bool { + if matches!(model, LocalModel::Am(_)) { + return false; + } cfg!(all( target_os = "macos", any(target_arch = "arm", target_arch = "aarch64") diff --git a/apps/cli/src/commands/record/mod.rs b/apps/cli/src/commands/record/mod.rs index cd600c3900..99b51efa51 100644 --- a/apps/cli/src/commands/record/mod.rs +++ b/apps/cli/src/commands/record/mod.rs @@ -131,22 +131,22 @@ async fn run_with_audio( let capture = runtime::capture(audio, args.audio, sample_rate, chunk_size, |progress| { app.update(&progress); - if progress.emit_event { - if let Some(writer) = event_writer.as_mut() { - writer.emit(&RecordEvent::Progress { - elapsed_ms: progress.elapsed.as_millis() as u64, - audio_secs: progress.audio_secs, - sample_count: progress.sample_count, - level_left: progress.left_level, - level_right: progress.right_level, - })?; - } + if progress.emit_event + && let Some(writer) = event_writer.as_mut() + { + writer.emit(&RecordEvent::Progress { + elapsed_ms: progress.elapsed.as_millis() as u64, + audio_secs: progress.audio_secs, + sample_count: progress.sample_count, + level_left: progress.left_level, + level_right: progress.right_level, + })?; } - if progress.render_ui { - if let Some(view) = viewport.as_mut() { - view.poll_input(); - view.draw(&app.lines()); - } + if progress.render_ui + && let Some(view) = viewport.as_mut() + { + view.poll_input(); + view.draw(&app.lines()); } Ok(()) }) diff --git a/apps/cli/src/commands/transcribe/mod.rs b/apps/cli/src/commands/transcribe/mod.rs index da0c3f03d3..c52decbb1d 100644 --- a/apps/cli/src/commands/transcribe/mod.rs +++ b/apps/cli/src/commands/transcribe/mod.rs @@ -1,5 +1,4 @@ mod output; -mod response; use std::io::BufWriter; use std::path::{Path, PathBuf}; @@ -10,7 +9,7 @@ use 
serde::Serialize; use tokio::sync::mpsc; use hypr_listener2_core::{BatchErrorCode, BatchEvent}; -use owhisper_interface::stream::StreamResponse; +use owhisper_interface::batch_stream::BatchStreamEvent; use crate::OptTraceBuffer; use crate::app::AppContext; @@ -92,9 +91,9 @@ async fn start_batch( stt: SttOverrides, on_normalize_progress: Option<&mut dyn FnMut(f64)>, ) -> CliResult { + let resolved = resolve_config(None, stt).await?; let (normalized_input_dir, normalized_input_path) = normalize_input_file(input.path(), on_normalize_progress)?; - let resolved = resolve_config(None, stt).await?; let params = build_batch_params(&resolved, &normalized_input_path, keywords)?; let (batch_tx, batch_rx) = mpsc::unbounded_channel::(); @@ -149,13 +148,8 @@ fn build_batch_params( )) } -fn extract_stream_transcript(response: &StreamResponse) -> Option<&str> { - match response { - StreamResponse::TranscriptResponse { channel, .. } => { - channel.alternatives.first().map(|a| a.transcript.as_str()) - } - _ => None, - } +fn extract_stream_transcript(event: &BatchStreamEvent) -> Option<&str> { + event.text() } struct CollectedBatch { @@ -168,36 +162,32 @@ fn finish_batch( Result, tokio::task::JoinError, >, - batch_response: Option, - streamed_segments: Vec, failure: Option<(BatchErrorCode, String)>, started: std::time::Instant, ) -> CliResult { let result = task_result .map_err(|e| CliError::operation_failed("batch transcription", e.to_string()))?; - if let Err(error) = result { + let output = if let Ok(output) = result { + output + } else { + let error = result.err().unwrap(); let message = if let Some((code, message)) = failure { format!("{code:?}: {message}") } else { error.to_string() }; return Err(CliError::operation_failed("batch transcription", message)); - } - - let response = batch_response - .or_else(|| response::batch_response_from_streams(streamed_segments)) - .ok_or_else(|| { - CliError::operation_failed("batch transcription", "completed without a final response") - })?; 
+ }; Ok(CollectedBatch { - response, + response: output.response, elapsed: started.elapsed(), }) } // -- Entry point -- +#[allow(clippy::unit_arg)] pub async fn run(ctx: &AppContext, args: Args) -> CliResult<()> { let format = args.format; let output_path = args.output.clone(); @@ -246,43 +236,30 @@ async fn run_json( } }; - let mut batch_response: Option = None; - let mut streamed_segments: Vec = Vec::new(); let mut failure: Option<(BatchErrorCode, String)> = None; while let Some(event) = handle.rx.recv().await { match event { BatchEvent::BatchStarted { .. } | BatchEvent::BatchCompleted { .. } => {} BatchEvent::BatchResponseStreamed { - response: streamed, - percentage, - .. + event: streamed, .. } => { let transcript = extract_stream_transcript(&streamed) .unwrap_or("") .to_string(); writer.emit(&TranscribeEvent::Progress { - percentage, + percentage: streamed.percentage(), transcript, })?; - streamed_segments.push(streamed); - } - BatchEvent::BatchResponse { response: next, .. } => { - batch_response = Some(next); } + BatchEvent::BatchResponse { .. } => {} BatchEvent::BatchFailed { code, error, .. } => { failure = Some((code, error)); } } } - let result = match finish_batch( - handle.task.await, - batch_response, - streamed_segments, - failure, - handle.started, - ) { + let result = match finish_batch(handle.task.await, failure, handle.started) { Ok(r) => r, Err(e) => { let _ = writer.emit(&TranscribeEvent::Failed { @@ -369,8 +346,6 @@ async fn run_pretty( #[cfg(not(feature = "standalone"))] let mut handle = start_batch(&args.input, args.keywords.clone(), stt, None).await?; - let mut batch_response: Option = None; - let mut streamed_segments: Vec = Vec::new(); let mut failure: Option<(BatchErrorCode, String)> = None; #[cfg(feature = "standalone")] @@ -411,25 +386,23 @@ async fn run_pretty( match event { BatchEvent::BatchStarted { .. } | BatchEvent::BatchCompleted { .. 
} => {} BatchEvent::BatchResponseStreamed { - response: streamed, - #[cfg(feature = "standalone")] - percentage, - .. + event: streamed, .. } => { #[cfg(feature = "standalone")] { - last_pct = percentage; - if let Some(t) = extract_stream_transcript(&streamed) { - if !t.is_empty() { - last_transcript = t.to_string(); - } + last_pct = streamed.percentage(); + if let Some(t) = extract_stream_transcript(&streamed) + && !t.is_empty() + { + last_transcript = t.to_string(); } } - streamed_segments.push(streamed); - } - BatchEvent::BatchResponse { response: next, .. } => { - batch_response = Some(next); + #[cfg(not(feature = "standalone"))] + { + let _ = streamed; + } } + BatchEvent::BatchResponse { .. } => {} BatchEvent::BatchFailed { code, error, .. } => { failure = Some((code, error)); } @@ -442,13 +415,7 @@ async fn run_pretty( .map_err(|e| CliError::operation_failed("clear viewport", e.to_string()))?; } - let result = finish_batch( - handle.task.await, - batch_response, - streamed_segments, - failure, - handle.started, - )?; + let result = finish_batch(handle.task.await, failure, handle.started)?; let response = &result.response; let pretty = output::format_pretty(response); diff --git a/apps/cli/src/commands/transcribe/output.rs b/apps/cli/src/commands/transcribe/output.rs index aa6ca86e1e..3ac5509c27 100644 --- a/apps/cli/src/commands/transcribe/output.rs +++ b/apps/cli/src/commands/transcribe/output.rs @@ -21,7 +21,7 @@ struct Segment<'a> { start: f64, end: f64, words: Vec<&'a str>, - channel: usize, + identity: usize, } pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> String { @@ -32,11 +32,11 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S text: &'a str, start: f64, end: f64, - speaker: usize, + identity: usize, } let mut all_words: Vec = Vec::new(); - for (ch_idx, channel) in response.results.channels.iter().enumerate() { + for (channel_idx, channel) in 
response.results.channels.iter().enumerate() { let Some(alt) = channel.alternatives.first() else { continue; }; @@ -48,7 +48,7 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S .unwrap_or(word.word.as_str()), start: word.start, end: word.end, - speaker: word.speaker.unwrap_or(ch_idx), + identity: word_identity(word, channel_idx, num_channels), }); } } @@ -61,7 +61,7 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S for word in &all_words { let should_split = segments .last() - .map(|seg| word.start - seg.end > PAUSE_THRESHOLD_SECS || word.speaker != seg.channel) + .map(|seg| word.start - seg.end > PAUSE_THRESHOLD_SECS || word.identity != seg.identity) .unwrap_or(true); if should_split { @@ -69,7 +69,7 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S start: word.start, end: word.end, words: vec![word.text], - channel: word.speaker, + identity: word.identity, }); } else if let Some(seg) = segments.last_mut() { seg.end = word.end; @@ -83,7 +83,7 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S let term_width = textwrap::termwidth(); let show_speaker = - num_channels > 1 || segments.iter().any(|s| s.channel != segments[0].channel); + num_channels > 1 || segments.iter().any(|s| s.identity != segments[0].identity); segments .iter() @@ -99,7 +99,7 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S let label = format!("{} ", timestamp); let text = seg.words.join(" "); let text = if show_speaker { - text.color(speaker_color(seg.channel)).to_string() + text.color(speaker_color(seg.identity)).to_string() } else { text }; @@ -137,6 +137,18 @@ pub(super) fn format_pretty(response: &owhisper_interface::batch::Response) -> S .join("\n\n") } +fn word_identity( + word: &owhisper_interface::batch::Word, + channel_idx: usize, + total_channels: usize, +) -> usize { + if total_channels > 1 { + channel_idx + } else { + 
word.speaker.unwrap_or(word.channel.max(0) as usize) + } +} + pub(super) fn extract_transcript(response: &owhisper_interface::batch::Response) -> String { response .results @@ -148,3 +160,78 @@ pub(super) fn extract_transcript(response: &owhisper_interface::batch::Response) .collect::>() .join("\n") } + +#[cfg(test)] +mod tests { + use owhisper_interface::batch; + + use super::*; + + fn response_with_channels(channel_words: Vec>) -> batch::Response { + batch::Response { + metadata: serde_json::json!({}), + results: batch::Results { + channels: channel_words + .into_iter() + .map(|words| batch::Channel { + alternatives: vec![batch::Alternatives { + transcript: words + .iter() + .map(|word| word.word.as_str()) + .collect::>() + .join(" "), + confidence: 1.0, + words, + }], + }) + .collect(), + }, + } + } + + fn word(text: &str, start: f64, end: f64, channel: i32, speaker: Option) -> batch::Word { + batch::Word { + word: text.to_string(), + start, + end, + confidence: 1.0, + channel, + speaker, + punctuated_word: Some(text.to_string()), + } + } + + #[test] + fn pretty_output_splits_multichannel_words_by_channel() { + colored::control::set_override(false); + + let response = response_with_channels(vec![ + vec![word("left", 0.0, 0.4, 0, None)], + vec![word("right", 0.1, 0.5, 0, None)], + ]); + + let pretty = format_pretty(&response); + let blocks = pretty.split("\n\n").collect::>(); + + assert_eq!(blocks.len(), 2); + assert!(blocks[0].contains("left")); + assert!(blocks[1].contains("right")); + } + + #[test] + fn pretty_output_splits_single_channel_words_by_speaker() { + colored::control::set_override(false); + + let response = response_with_channels(vec![vec![ + word("hello", 0.0, 0.4, 0, Some(0)), + word("again", 0.45, 0.8, 0, Some(1)), + ]]); + + let pretty = format_pretty(&response); + let blocks = pretty.split("\n\n").collect::>(); + + assert_eq!(blocks.len(), 2); + assert!(blocks[0].contains("hello")); + assert!(blocks[1].contains("again")); + } +} diff --git 
a/apps/cli/src/commands/transcribe/response.rs b/apps/cli/src/commands/transcribe/response.rs deleted file mode 100644 index 4976b22f18..0000000000 --- a/apps/cli/src/commands/transcribe/response.rs +++ /dev/null @@ -1,64 +0,0 @@ -use owhisper_interface::batch; -use owhisper_interface::stream::StreamResponse; - -pub(super) fn batch_response_from_streams( - segments: Vec, -) -> Option { - if segments.is_empty() { - return None; - } - - let mut all_words: Vec = Vec::new(); - let mut all_transcripts: Vec = Vec::new(); - let mut total_confidence = 0.0; - let mut max_end = 0.0_f64; - let mut count = 0usize; - - for segment in segments { - let StreamResponse::TranscriptResponse { - channel, - start, - duration, - .. - } = segment - else { - continue; - }; - - let Some(alt) = channel.alternatives.into_iter().next() else { - continue; - }; - - let text = alt.transcript.trim().to_string(); - if text.is_empty() { - continue; - } - - let words: Vec = alt.words.into_iter().map(batch::Word::from).collect(); - all_words.extend(words); - all_transcripts.push(text); - total_confidence += alt.confidence; - max_end = max_end.max(start + duration); - count += 1; - } - - if count == 0 { - return None; - } - - let transcript = all_transcripts.join(" "); - let avg_confidence = total_confidence / count as f64; - - Some(batch::Response { - metadata: serde_json::json!({ "duration": max_end }), - results: batch::Results { - channels: vec![batch::Channel { - alternatives: vec![batch::Alternatives { - transcript, - confidence: avg_confidence, - words: all_words, - }], - }], - }, - }) -} diff --git a/apps/cli/src/config/paths.rs b/apps/cli/src/config/paths.rs index 791143191c..b632daf6c2 100644 --- a/apps/cli/src/config/paths.rs +++ b/apps/cli/src/config/paths.rs @@ -2,6 +2,7 @@ use std::path::{Path, PathBuf}; #[derive(Clone, Debug)] pub struct AppPaths { + #[allow(dead_code)] pub base: PathBuf, pub models_base: PathBuf, } diff --git a/apps/cli/src/config/settings.rs 
b/apps/cli/src/config/settings.rs index 63f045b642..98e98a4842 100644 --- a/apps/cli/src/config/settings.rs +++ b/apps/cli/src/config/settings.rs @@ -16,54 +16,6 @@ pub struct Settings { pub stt_providers: HashMap, } -pub async fn load_settings(pool: &SqlitePool) -> Option { - let all = hypr_db_app::load_all_settings(pool).await.ok()?; - let setting_map: HashMap = all.into_iter().collect(); - - let current_stt_provider = setting_map - .get("current_stt_provider") - .filter(|v| !v.is_empty()) - .cloned(); - let current_stt_model = setting_map - .get("current_stt_model") - .filter(|v| !v.is_empty()) - .cloned(); - - let stt_connections = hypr_db_app::list_connections(pool, "stt").await.ok()?; - let stt_providers = connections_to_provider_map(stt_connections); - - if current_stt_provider.is_none() && stt_providers.is_empty() { - return None; - } - - Some(Settings { - current_stt_provider, - current_stt_model, - stt_providers, - }) -} - -fn connections_to_provider_map( - connections: Vec, -) -> HashMap { - connections - .into_iter() - .map(|c| { - let base_url = if c.base_url.is_empty() { - None - } else { - Some(c.base_url) - }; - let api_key = if c.api_key.is_empty() { - None - } else { - Some(c.api_key) - }; - (c.provider_id, ProviderConfig { base_url, api_key }) - }) - .collect() -} - pub async fn migrate_json_settings_to_db(pool: &SqlitePool, base_path: &Path) { let has_settings = hypr_db_app::load_all_settings(pool) .await diff --git a/apps/cli/src/error.rs b/apps/cli/src/error.rs index c425217a1e..568533a495 100644 --- a/apps/cli/src/error.rs +++ b/apps/cli/src/error.rs @@ -67,20 +67,6 @@ impl CliError { } } - pub fn invalid_argument_with_hint( - name: &'static str, - value: impl Into, - reason: impl Into, - hint: impl Into, - ) -> Self { - Self::InvalidArgument { - name, - value: value.into(), - reason: reason.into(), - hint: Some(hint.into()), - } - } - pub fn operation_failed(action: &'static str, reason: impl Into) -> Self { Self::OperationFailed { action, 
diff --git a/apps/cli/src/main.rs b/apps/cli/src/main.rs index 454c2e926f..5be84e7ab5 100644 --- a/apps/cli/src/main.rs +++ b/apps/cli/src/main.rs @@ -13,6 +13,7 @@ use crate::error::CliResult; use clap::Parser; #[tokio::main] +#[allow(clippy::let_unit_value)] async fn main() { let cli = Cli::parse(); @@ -82,7 +83,7 @@ fn init_tracing(cli: &Cli) -> OptTraceBuffer { #[cfg(feature = "standalone")] return None; #[cfg(not(feature = "standalone"))] - return (); + return; } fn init_tracing_stderr(level: tracing_subscriber::filter::LevelFilter) { diff --git a/apps/cli/src/output.rs b/apps/cli/src/output.rs index 92aaf862a2..1bdd9904bd 100644 --- a/apps/cli/src/output.rs +++ b/apps/cli/src/output.rs @@ -1,19 +1,8 @@ +use crate::error::{CliError, CliResult}; use std::io::{IsTerminal, Write}; use std::path::Path; -use std::time::Duration; - -use crate::error::{CliError, CliResult}; - -pub fn format_hhmmss(duration: Duration) -> String { - let secs = duration.as_secs(); - format!( - "{:02}:{:02}:{:02}", - secs / 3600, - (secs % 3600) / 60, - secs % 60 - ) -} +#[allow(dead_code)] pub fn format_timestamp_ms(ms: i64) -> String { let total_secs = (ms / 1000).max(0); let mins = total_secs / 60; diff --git a/apps/cli/src/stt/config.rs b/apps/cli/src/stt/config.rs index 9598ebd688..b9823be427 100644 --- a/apps/cli/src/stt/config.rs +++ b/apps/cli/src/stt/config.rs @@ -332,12 +332,24 @@ fn canonical_cactus_name(name: &str) -> String { } } -fn default_cactus_model() -> CactusSttModel { - if cfg!(target_arch = "aarch64") && cfg!(target_os = "macos") { - CactusSttModel::WhisperSmallInt8Apple - } else { - CactusSttModel::WhisperSmallInt8 - } +fn missing_cactus_model_error() -> CliError { + CliError::required_argument_with_hint( + "--model", + format!( + "Pass --model explicitly for --provider cactus. 
Valid models: {}", + cactus_model_names().join(", ") + ), + ) +} + +fn cactus_model_names() -> Vec<&'static str> { + LocalModel::all() + .iter() + .filter_map(|model| match model { + LocalModel::Cactus(_) => Some(model.cli_name()), + _ => None, + }) + .collect() } fn resolve_cactus_model( @@ -348,23 +360,19 @@ fn resolve_cactus_model( return Err(unsupported_cactus_error()); } - let model = match name { - Some(name) => { - let canonical = canonical_cactus_name(name); - LocalModel::all() - .into_iter() - .find_map(|model| match model { - LocalModel::Cactus(cactus) - if model.cli_name() == name || model.cli_name() == canonical => - { - Some(cactus) - } - _ => None, - }) - .ok_or_else(|| not_found_cactus_model(models_base, name, false))? - } - None => default_cactus_model(), - }; + let name = name.ok_or_else(missing_cactus_model_error)?; + let canonical = canonical_cactus_name(name); + let model = LocalModel::all() + .into_iter() + .find_map(|model| match model { + LocalModel::Cactus(cactus) + if model.cli_name() == name || model.cli_name() == canonical => + { + Some(cactus) + } + _ => None, + }) + .ok_or_else(|| not_found_cactus_model(models_base, name, false))?; let model_path = LocalModel::Cactus(model.clone()).install_path(models_base); if !model_path.exists() { diff --git a/apps/cli/src/stt/provider.rs b/apps/cli/src/stt/provider.rs index f830b26b2d..952be7e577 100644 --- a/apps/cli/src/stt/provider.rs +++ b/apps/cli/src/stt/provider.rs @@ -87,7 +87,7 @@ impl SttProvider { self.meta().cloud_provider } - pub(crate) fn to_batch_provider(&self) -> BatchProvider { + pub(crate) fn to_batch_provider(self) -> BatchProvider { self.meta().batch_provider } } diff --git a/apps/cli/src/tui/mod.rs b/apps/cli/src/tui/mod.rs index 923c0b3362..88a2da1072 100644 --- a/apps/cli/src/tui/mod.rs +++ b/apps/cli/src/tui/mod.rs @@ -72,9 +72,12 @@ impl BackgroundInput { Some(InputAction::Interrupt) => { // Raw mode swallows Ctrl+C, so send SIGINT back to the // process and let the async 
runtime shut down naturally. + #[cfg(unix)] unsafe { libc::kill(libc::getpid(), libc::SIGINT); } + #[cfg(not(unix))] + std::process::exit(130); } None => {} } diff --git a/apps/desktop/src/store/zustand/listener/batch.ts b/apps/desktop/src/store/zustand/listener/batch.ts index a6f825c3e5..07805cd2f9 100644 --- a/apps/desktop/src/store/zustand/listener/batch.ts +++ b/apps/desktop/src/store/zustand/listener/batch.ts @@ -1,6 +1,6 @@ import type { StoreApi } from "zustand"; -import type { BatchResponse, StreamResponse } from "@hypr/plugin-listener2"; +import type { BatchResponse, BatchStreamEvent } from "@hypr/plugin-listener2"; import type { BatchPersistCallback } from "./transcript"; import { transformWordEntries } from "./utils"; @@ -35,8 +35,7 @@ export type BatchActions = { handleBatchResponse: (sessionId: string, response: BatchResponse) => void; handleBatchResponseStreamed: ( sessionId: string, - response: StreamResponse, - percentage: number, + event: BatchStreamEvent, ) => void; handleBatchFailed: (sessionId: string, error: string) => void; updateBatchProgress: (sessionId: string, percentage: number) => void; @@ -114,8 +113,9 @@ export const createBatchSlice = ( }); }, - handleBatchResponseStreamed: (sessionId, response, percentage) => { - const isComplete = response.type === "Results" && response.from_finalize; + handleBatchResponseStreamed: (sessionId, event) => { + const percentage = getBatchStreamPercentage(event); + const isComplete = event.type === "result" || event.type === "terminal"; set((state) => ({ ...state, @@ -134,7 +134,7 @@ export const createBatchSlice = ( wordsByChannel: {}, hintsByChannel: {}, }, - response, + event, ), }, })); @@ -255,8 +255,13 @@ function mergeBatchPreview( wordsByChannel: Record; hintsByChannel: Record; }, - response: StreamResponse, + event: BatchStreamEvent, ) { + if (event.type !== "segment") { + return preview; + } + + const response = event.response; if (response.type !== "Results") { return preview; } @@ -337,3 
+342,16 @@ function mergeBatchPreview( }, }; } + +function getBatchStreamPercentage(event: BatchStreamEvent): number { + switch (event.type) { + case "progress": + case "segment": + return event.percentage; + case "result": + case "terminal": + return 1; + case "error": + return 0; + } +} diff --git a/apps/desktop/src/store/zustand/listener/general-batch.ts b/apps/desktop/src/store/zustand/listener/general-batch.ts index e37ee70a86..849c799e15 100644 --- a/apps/desktop/src/store/zustand/listener/general-batch.ts +++ b/apps/desktop/src/store/zustand/listener/general-batch.ts @@ -86,11 +86,7 @@ export const runBatchSession = async ( } if (payload.type === "batchProgress") { - get().handleBatchResponseStreamed( - sessionId, - payload.response, - payload.percentage, - ); + get().handleBatchResponseStreamed(sessionId, payload.event); return; } diff --git a/apps/desktop/src/store/zustand/listener/general.test.ts b/apps/desktop/src/store/zustand/listener/general.test.ts index db5de971bb..a9070a6cf7 100644 --- a/apps/desktop/src/store/zustand/listener/general.test.ts +++ b/apps/desktop/src/store/zustand/listener/general.test.ts @@ -40,35 +40,13 @@ describe("General Listener Slice", () => { const sessionId = "session-456"; const { handleBatchResponseStreamed, getSessionMode } = store.getState(); - const mockResponse = { - type: "Results" as const, - start: 0, - duration: 5, - is_final: false, - speech_final: false, - from_finalize: false, - channel: { - alternatives: [ - { - transcript: "test", - words: [], - confidence: 0.9, - }, - ], - }, - metadata: { - request_id: "test-request", - model_info: { - name: "test-model", - version: "1.0", - arch: "test-arch", - }, - model_uuid: "test-uuid", - }, - channel_index: [0], + const mockEvent = { + type: "progress" as const, + percentage: 0.5, + partial_text: "test", }; - handleBatchResponseStreamed(sessionId, mockResponse, 0.5); + handleBatchResponseStreamed(sessionId, mockEvent); 
expect(getSessionMode(sessionId)).toBe("running_batch"); }); }); @@ -79,46 +57,50 @@ describe("General Listener Slice", () => { const { handleBatchResponseStreamed, clearBatchSession } = store.getState(); - const mockResponse = { - type: "Results" as const, - start: 0, - duration: 5, - is_final: false, - speech_final: false, - from_finalize: false, - channel: { - alternatives: [ - { - transcript: "test", - languages: [], - words: [ - { - word: "test", - punctuated_word: "test", - start: 0, - end: 0.5, - confidence: 0.9, - speaker: null, - language: null, - }, - ], - confidence: 0.9, + const mockEvent = { + type: "segment" as const, + percentage: 0.5, + response: { + type: "Results" as const, + start: 0, + duration: 5, + is_final: false, + speech_final: false, + from_finalize: false, + channel: { + alternatives: [ + { + transcript: "test", + languages: [], + words: [ + { + word: "test", + punctuated_word: "test", + start: 0, + end: 0.5, + confidence: 0.9, + speaker: null, + language: null, + }, + ], + confidence: 0.9, + }, + ], + }, + metadata: { + request_id: "test-request", + model_info: { + name: "test-model", + version: "1.0", + arch: "test-arch", }, - ], - }, - metadata: { - request_id: "test-request", - model_info: { - name: "test-model", - version: "1.0", - arch: "test-arch", + model_uuid: "test-uuid", }, - model_uuid: "test-uuid", + channel_index: [0], }, - channel_index: [0], }; - handleBatchResponseStreamed(sessionId, mockResponse, 0.5); + handleBatchResponseStreamed(sessionId, mockEvent); expect(store.getState().batch[sessionId]).toEqual({ percentage: 0.5, isComplete: false, diff --git a/apps/desktop/src/store/zustand/listener/utils.ts b/apps/desktop/src/store/zustand/listener/utils.ts index c09ceca6d8..010f98845f 100644 --- a/apps/desktop/src/store/zustand/listener/utils.ts +++ b/apps/desktop/src/store/zustand/listener/utils.ts @@ -34,6 +34,7 @@ export type WordEntry = { punctuated_word?: string | null; start: number; end: number; + channel?: number; 
speaker?: number | null; }; @@ -59,7 +60,7 @@ export function transformWordEntries( text, start_ms: Math.round(word.start * 1000), end_ms: Math.round(word.end * 1000), - channel, + channel: typeof word.channel === "number" ? word.channel : channel, }); if (typeof word.speaker === "number") { diff --git a/crates/activity-capture-macos/Cargo.toml b/crates/activity-capture-macos/Cargo.toml index ba3ca4063e..de2677a541 100644 --- a/crates/activity-capture-macos/Cargo.toml +++ b/crates/activity-capture-macos/Cargo.toml @@ -4,13 +4,13 @@ version = "0.1.0" edition = "2024" [dependencies] -block2 = { workspace = true } futures-core = { workspace = true } hypr-activity-capture-interface = { workspace = true } tokio = { workspace = true, features = ["sync"] } tokio-stream = { workspace = true } [target.'cfg(target_os = "macos")'.dependencies] +block2 = { workspace = true } objc2 = { workspace = true } objc2-app-kit = { workspace = true, features = ["NSWorkspace", "NSRunningApplication"] } objc2-application-services = { workspace = true } diff --git a/crates/audio-actual/src/rt_ring.rs b/crates/audio-actual/src/rt_ring.rs index 84d93aedf5..cd83ace601 100644 --- a/crates/audio-actual/src/rt_ring.rs +++ b/crates/audio-actual/src/rt_ring.rs @@ -197,9 +197,9 @@ where let convert_count = count.min(vacant); - for i in 0..convert_count { + for (i, slot) in scratch[..convert_count].iter_mut().enumerate() { let byte_offset = (offset + i) * frame_size; - scratch[i] = f32::from_le_bytes([ + *slot = f32::from_le_bytes([ data[byte_offset], data[byte_offset + 1], data[byte_offset + 2], diff --git a/crates/audio-actual/src/speaker/windows.rs b/crates/audio-actual/src/speaker/windows.rs index bf398fccd6..d156e11b26 100644 --- a/crates/audio-actual/src/speaker/windows.rs +++ b/crates/audio-actual/src/speaker/windows.rs @@ -5,7 +5,7 @@ use hypr_audio_utils::{pcm_i16_to_f32, pcm_i32_to_f32}; use pin_project::pin_project; use ringbuf::{ HeapCons, HeapProd, HeapRb, - traits::{Producer, Split}, + 
traits::{Observer, Producer, Split}, }; use std::collections::VecDeque; use std::sync::Arc; diff --git a/crates/audio-sync/src/probe.rs b/crates/audio-sync/src/probe.rs index 3b00e69e6a..23bb306670 100644 --- a/crates/audio-sync/src/probe.rs +++ b/crates/audio-sync/src/probe.rs @@ -417,11 +417,10 @@ impl SyncProbe { return self.snapshot(SyncProbeState::Locked, Some(lag_samples), confidence); } - if let Some(center) = self.acquisition_center() { - if (lag_samples - center).unsigned_abs() > self.tuning.acquire_cluster_tolerance_samples - { - self.acquisition_entries.clear(); - } + if let Some(center) = self.acquisition_center() + && (lag_samples - center).unsigned_abs() > self.tuning.acquire_cluster_tolerance_samples + { + self.acquisition_entries.clear(); } self.push_acquisition_entry(Some(lag_samples)); @@ -575,7 +574,7 @@ fn median_isize(iter: impl IntoIterator) -> Option { values.sort_unstable(); let mid = values.len() / 2; - Some(if values.len() % 2 == 0 { + Some(if values.len().is_multiple_of(2) { (values[mid - 1] + values[mid]) / 2 } else { values[mid] diff --git a/crates/db-app/src/lib.rs b/crates/db-app/src/lib.rs index 52050f454c..5554ab03b8 100644 --- a/crates/db-app/src/lib.rs +++ b/crates/db-app/src/lib.rs @@ -1,4 +1,5 @@ #![forbid(unsafe_code)] +#![allow(clippy::too_many_arguments)] #[cfg(feature = "cli")] pub mod human_cli; diff --git a/crates/db-core2/src/lib.rs b/crates/db-core2/src/lib.rs index ec5101befc..064d35c1ca 100644 --- a/crates/db-core2/src/lib.rs +++ b/crates/db-core2/src/lib.rs @@ -49,7 +49,7 @@ impl Db3 { pub async fn connect_local_plain(path: impl AsRef) -> Result { if let Some(parent) = path.as_ref().parent() { - std::fs::create_dir_all(parent).map_err(|e| sqlx::Error::Io(e))?; + std::fs::create_dir_all(parent).map_err(sqlx::Error::Io)?; } let options = SqliteConnectOptions::new() .filename(path) diff --git a/crates/listener2-core/src/batch/accumulator.rs b/crates/listener2-core/src/batch/accumulator.rs index 375f8633a8..03ebb16b55 
100644 --- a/crates/listener2-core/src/batch/accumulator.rs +++ b/crates/listener2-core/src/batch/accumulator.rs @@ -1,5 +1,7 @@ use std::collections::BTreeMap; +use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::stream::StreamResponse; use super::{BatchRunMode, BatchRunOutput}; @@ -10,6 +12,7 @@ pub(super) struct StreamBatchAccumulator { max_duration_secs: f64, terminal_duration_secs: Option, terminal_channels: Option, + final_response: Option, } #[derive(Default)] @@ -26,74 +29,100 @@ impl StreamBatchAccumulator { Self::default() } - pub(super) fn observe(&mut self, response: &StreamResponse) { - match response { - StreamResponse::TranscriptResponse { - start, - duration, - from_finalize, - channel, - channel_index, - .. - } => { - self.max_duration_secs = self.max_duration_secs.max((*start + *duration).max(0.0)); - - let channel_id = channel_index.first().copied().unwrap_or(0); - let Some(alternative) = channel.alternatives.first() else { - return; - }; - - let state = self.channels.entry(channel_id).or_default(); - let transcript = alternative.transcript.trim(); - - if !alternative.words.is_empty() { - let words = alternative - .words - .iter() - .cloned() - .map(owhisper_interface::batch::Word::from) - .collect::>(); - - let should_replace = *from_finalize - && *start <= 0.0 - && words.last().map(|word| word.end).unwrap_or_default() - >= state.words.last().map(|word| word.end).unwrap_or_default(); - - if should_replace { - state.words = words; - } else { - append_non_overlapping_words(&mut state.words, words); + pub(super) fn observe(&mut self, event: &BatchStreamEvent) { + match event { + BatchStreamEvent::Progress { .. } => {} + BatchStreamEvent::Segment { + response, + percentage: _, + } => match response { + StreamResponse::TranscriptResponse { + start, + duration, + from_finalize, + channel, + channel_index, + .. 
+ } => { + self.max_duration_secs = + self.max_duration_secs.max((*start + *duration).max(0.0)); + + let channel_id = channel_index.first().copied().unwrap_or(0); + let Some(alternative) = channel.alternatives.first() else { + return; + }; + + let state = self.channels.entry(channel_id).or_default(); + let transcript = alternative.transcript.trim(); + + if !alternative.words.is_empty() { + let mut words = alternative + .words + .iter() + .cloned() + .map(owhisper_interface::batch::Word::from) + .collect::>(); + for word in &mut words { + word.channel = channel_id; + } + + let should_replace = *from_finalize + && *start <= 0.0 + && words.last().map(|word| word.end).unwrap_or_default() + >= state.words.last().map(|word| word.end).unwrap_or_default(); + + if should_replace { + state.words = words; + } else { + append_non_overlapping_words(&mut state.words, words); + } } - } - if !transcript.is_empty() { - if *from_finalize { - state.final_transcript = Some(transcript.to_string()); - } else if state - .transcript_segments - .last() - .is_none_or(|existing| existing != transcript) - { - state.transcript_segments.push(transcript.to_string()); + if !transcript.is_empty() { + if *from_finalize { + state.final_transcript = Some(transcript.to_string()); + } else if state + .transcript_segments + .last() + .is_none_or(|existing| existing != transcript) + { + state.transcript_segments.push(transcript.to_string()); + } } - } - if alternative.confidence.is_finite() { - state.confidence_sum += alternative.confidence; - state.confidence_count += 1; + if alternative.confidence.is_finite() { + state.confidence_sum += alternative.confidence; + state.confidence_count += 1; + } } - } - StreamResponse::TerminalResponse { + StreamResponse::TerminalResponse { .. } => {} + StreamResponse::ErrorResponse { .. } + | StreamResponse::SpeechStartedResponse { .. } + | StreamResponse::UtteranceEndResponse { .. } => {} + _ => {} + }, + BatchStreamEvent::Terminal { duration, channels, .. 
} => { self.terminal_duration_secs = Some(*duration); self.terminal_channels = Some(*channels); } - _ => {} + BatchStreamEvent::Result { response } => { + self.final_response = Some(response.clone()); + } + BatchStreamEvent::Error { .. } => {} } } pub(super) fn finish(self, session_id: &str) -> BatchRunOutput { + if let Some(response) = self.final_response { + return BatchRunOutput { + session_id: session_id.to_string(), + mode: BatchRunMode::Streamed, + response, + }; + } + let channel_count = self .terminal_channels .map(|count| count as usize) @@ -183,6 +212,7 @@ fn append_non_overlapping_words( #[cfg(test)] mod test { + use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::stream::{Alternatives, Channel, Metadata, ModelInfo, Word}; use super::*; @@ -191,17 +221,23 @@ mod test { fn streamed_accumulator_uses_finalize_transcript_without_duplication() { let mut accumulator = StreamBatchAccumulator::new(); - accumulator.observe(&transcript_response( - 0.0, - 2.0, - false, - "hello world", - vec![ - stream_word("hello", 0.0, 0.8), - stream_word("world", 0.9, 1.5), - ], + accumulator.observe(&segment_response( + transcript_response( + 0.0, + 2.0, + false, + "hello world", + vec![ + stream_word("hello", 0.0, 0.8), + stream_word("world", 0.9, 1.5), + ], + ), + 0.5, + )); + accumulator.observe(&segment_response( + transcript_response(0.0, 2.0, true, "hello world", vec![]), + 1.0, )); - accumulator.observe(&transcript_response(0.0, 2.0, true, "hello world", vec![])); let output = accumulator.finish("session-1"); let channel = &output.response.results.channels[0].alternatives[0]; @@ -215,22 +251,28 @@ mod test { fn streamed_accumulator_replaces_words_with_full_finalize_snapshot() { let mut accumulator = StreamBatchAccumulator::new(); - accumulator.observe(&transcript_response( - 1.0, - 1.0, - false, - "world", - vec![stream_word("world", 1.0, 1.5)], + accumulator.observe(&segment_response( + transcript_response( + 1.0, + 1.0, + false, + "world", + 
vec![stream_word("world", 1.0, 1.5)], + ), + 0.5, )); - accumulator.observe(&transcript_response( - 0.0, - 2.0, - true, - "hello world", - vec![ - stream_word("hello", 0.0, 0.8), - stream_word("world", 1.0, 1.5), - ], + accumulator.observe(&segment_response( + transcript_response( + 0.0, + 2.0, + true, + "hello world", + vec![ + stream_word("hello", 0.0, 0.8), + stream_word("world", 1.0, 1.5), + ], + ), + 1.0, )); let output = accumulator.finish("session-2"); @@ -241,6 +283,13 @@ mod test { assert_eq!(channel.words[0].word, "hello"); } + fn segment_response(response: StreamResponse, percentage: f64) -> BatchStreamEvent { + BatchStreamEvent::Segment { + response, + percentage, + } + } + fn transcript_response( start: f64, duration: f64, diff --git a/crates/listener2-core/src/batch/actor.rs b/crates/listener2-core/src/batch/actor.rs index fc60c2a515..96a8bcb04b 100644 --- a/crates/listener2-core/src/batch/actor.rs +++ b/crates/listener2-core/src/batch/actor.rs @@ -2,13 +2,14 @@ use std::sync::{Arc, Mutex}; use std::time::Duration; use owhisper_client::StreamingBatchStream; +use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::stream::StreamResponse; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SpawnErr}; use tracing::Instrument; use super::accumulator::StreamBatchAccumulator; use super::bootstrap::{notify_start_result, spawn_batch_task}; -use super::{BatchParams, BatchRunMode, BatchRunOutput, format_user_friendly_error, session_span}; +use super::{BatchParams, BatchRunOutput, format_user_friendly_error, session_span}; use crate::{BatchEvent, BatchRuntime}; const BATCH_STREAM_TIMEOUT_SECS: u64 = 30; @@ -29,6 +30,7 @@ pub(super) async fn run_batch_streaming( let args = BatchArgs { runtime: runtime.clone(), + provider: params.provider, file_path: params.file_path, base_url: params.base_url, api_key: params.api_key, @@ -92,22 +94,19 @@ pub(super) async fn run_batch_streaming( .await } -fn is_completion_response(response: 
&StreamResponse) -> bool { +fn is_completion_event(event: &BatchStreamEvent) -> bool { matches!( - response, - StreamResponse::TranscriptResponse { - from_finalize: true, - .. - } | StreamResponse::TerminalResponse { .. } + event, + BatchStreamEvent::Result { .. } | BatchStreamEvent::Terminal { .. } ) } -fn provider_error_from_response(response: &StreamResponse) -> Option<(&str, &str, Option)> { - let StreamResponse::ErrorResponse { +fn provider_error_from_event(event: &BatchStreamEvent) -> Option<(&str, &str, Option)> { + let BatchStreamEvent::Error { provider, error_message, error_code, - } = response + } = event else { return None; }; @@ -117,11 +116,7 @@ fn provider_error_from_response(response: &StreamResponse) -> Option<(&str, &str #[allow(clippy::enum_variant_names)] pub(super) enum BatchMsg { - StreamResponse { - response: Box, - percentage: f64, - final_batch_response: Option, - }, + StreamResponse { event: Box }, StreamError(crate::BatchFailure), StreamEnded, StreamStartFailed(crate::BatchFailure), @@ -135,6 +130,7 @@ type BatchDoneNotifier = #[derive(Clone)] pub(super) struct BatchArgs { pub(super) runtime: Arc, + pub(super) provider: super::BatchProvider, pub(super) file_path: String, pub(super) base_url: String, pub(super) api_key: String, @@ -152,15 +148,13 @@ struct BatchState { done_notifier: BatchDoneNotifier, final_result: Option>, accumulator: StreamBatchAccumulator, - stashed_final: Option, } impl BatchState { - fn emit_streamed(&self, response: StreamResponse, percentage: f64) { + fn emit_streamed(&self, event: BatchStreamEvent) { self.runtime.emit(BatchEvent::BatchResponseStreamed { session_id: self.session_id.clone(), - response, - percentage, + event, }); } } @@ -199,7 +193,6 @@ impl Actor for BatchActor { done_notifier: args.done_notifier, final_result: None, accumulator: StreamBatchAccumulator::new(), - stashed_final: None, }) } @@ -228,17 +221,10 @@ impl Actor for BatchActor { state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { 
match message { - BatchMsg::StreamResponse { - response, - percentage, - final_batch_response, - } => { + BatchMsg::StreamResponse { event } => { tracing::info!("batch stream response received"); - state.accumulator.observe(&response); - state.emit_streamed(*response, percentage); - if let Some(final_resp) = final_batch_response { - state.stashed_final = Some(final_resp); - } + state.accumulator.observe(&event); + state.emit_streamed(*event); } BatchMsg::StreamStartFailed(error) => { tracing::error!("batch_stream_start_failed: {}", error); @@ -252,15 +238,7 @@ impl Actor for BatchActor { } BatchMsg::StreamEnded => { tracing::info!("batch_stream_ended"); - let output = if let Some(response) = state.stashed_final.take() { - BatchRunOutput { - session_id: state.session_id.clone(), - mode: BatchRunMode::Streamed, - response, - } - } else { - std::mem::take(&mut state.accumulator).finish(&state.session_id) - }; + let output = std::mem::take(&mut state.accumulator).finish(&state.session_id); state.final_result = Some(Ok(output)); myself.stop(None); } @@ -306,9 +284,14 @@ pub(super) async fn process_provider_stream( context: &str, ) { futures_util::pin_mut!(stream); - process_stream_loop(&mut stream, myself, shutdown_rx, context, 1, |event| { - (event.response, event.percentage, event.final_batch_response) - }) + process_stream_loop( + &mut stream, + myself, + shutdown_rx, + context, + 1, + std::convert::identity, + ) .await; } @@ -328,10 +311,7 @@ pub(super) async fn process_batch_stream( shutdown_rx, "batch stream", expected_completions, - |response| { - let percentage = compute_percentage(&response, audio_duration_secs); - (response, percentage, None) - }, + |response| batch_event_from_stream_response(response, audio_duration_secs), ) .await; } @@ -346,13 +326,7 @@ async fn process_stream_loop( ) where S: futures_util::Stream>, E: std::fmt::Debug, - F: FnMut( - Item, - ) -> ( - StreamResponse, - f64, - Option, - ), + F: FnMut(Item) -> BatchStreamEvent, { let mut 
response_count = 0; let response_timeout = Duration::from_secs(BATCH_STREAM_TIMEOUT_SECS); @@ -377,22 +351,22 @@ async fn process_stream_loop( match result { Ok(Some(Ok(item))) => { response_count += 1; - let (response, percentage, final_batch_response) = into_response(item); + let event = into_response(item); - let is_from_finalize = matches!( - &response, - StreamResponse::TranscriptResponse { from_finalize, .. } if *from_finalize - ); - let is_completion = is_completion_response(&response); + let is_completion = is_completion_event(&event); tracing::info!( "{context}: response #{}{}", response_count, - if is_from_finalize { " (from_finalize)" } else { "" } + if matches!(&event, BatchStreamEvent::Result { .. }) { + " (result)" + } else { + "" + } ); if let Some((provider, error_message, error_code)) = - provider_error_from_response(&response) + provider_error_from_event(&event) { tracing::error!( hyprnote.stt.provider.name = %provider, @@ -411,11 +385,14 @@ async fn process_stream_loop( break; } - send_actor_message(&myself, BatchMsg::StreamResponse { - response: Box::new(response), - percentage, - final_batch_response, - }, context, "stream response"); + send_actor_message( + &myself, + BatchMsg::StreamResponse { + event: Box::new(event), + }, + context, + "stream response", + ); if is_completion { completions_seen += 1; @@ -512,6 +489,44 @@ fn compute_percentage(response: &StreamResponse, audio_duration_secs: f64) -> f6 } } +fn batch_event_from_stream_response( + response: StreamResponse, + audio_duration_secs: f64, +) -> BatchStreamEvent { + let percentage = compute_percentage(&response, audio_duration_secs); + + match response { + StreamResponse::TranscriptResponse { .. 
} => BatchStreamEvent::Segment { + response, + percentage, + }, + StreamResponse::TerminalResponse { + request_id, + created, + duration, + channels, + } => BatchStreamEvent::Terminal { + request_id, + created, + duration, + channels, + }, + StreamResponse::ErrorResponse { + error_code, + error_message, + provider, + } => BatchStreamEvent::Error { + error_code, + error_message, + provider, + }, + other => BatchStreamEvent::Segment { + response: other, + percentage, + }, + } +} + fn transcript_end_from_response(response: &StreamResponse) -> Option { let StreamResponse::TranscriptResponse { start, @@ -538,63 +553,41 @@ fn transcript_end_from_response(response: &StreamResponse) -> Option { #[cfg(test)] mod test { - use owhisper_interface::stream::{Alternatives, Channel, Metadata, ModelInfo}; - use super::*; #[test] - fn completion_response_from_finalize() { - let response = StreamResponse::TranscriptResponse { - start: 0.0, - duration: 0.1, - is_final: true, - speech_final: true, - from_finalize: true, - channel: Channel { - alternatives: vec![Alternatives { - transcript: "hi".to_string(), - words: Vec::new(), - confidence: 1.0, - languages: Vec::new(), - }], - }, - metadata: Metadata { - request_id: "r".to_string(), - model_info: ModelInfo { - name: "".to_string(), - version: "".to_string(), - arch: "".to_string(), - }, - model_uuid: "m".to_string(), - extra: None, + fn completion_event_result() { + let event = BatchStreamEvent::Result { + response: owhisper_interface::batch::Response { + metadata: serde_json::json!({ "duration": 0.1 }), + results: owhisper_interface::batch::Results { channels: vec![] }, }, - channel_index: vec![0, 1], }; - assert!(is_completion_response(&response)); + assert!(is_completion_event(&event)); } #[test] - fn completion_response_terminal() { - let response = StreamResponse::TerminalResponse { + fn completion_event_terminal() { + let event = BatchStreamEvent::Terminal { request_id: "r".to_string(), created: "now".to_string(), duration: 
1.0, channels: 1, }; - assert!(is_completion_response(&response)); + assert!(is_completion_event(&event)); } #[test] fn provider_error_extracts_fields() { - let response = StreamResponse::ErrorResponse { + let event = BatchStreamEvent::Error { error_code: Some(42), error_message: "nope".to_string(), provider: "x".to_string(), }; - let extracted = provider_error_from_response(&response); + let extracted = provider_error_from_event(&event); assert_eq!(extracted, Some(("x", "nope", Some(42)))); } } diff --git a/crates/listener2-core/src/batch/bootstrap.rs b/crates/listener2-core/src/batch/bootstrap.rs index 6bc24c7f95..4db49400cd 100644 --- a/crates/listener2-core/src/batch/bootstrap.rs +++ b/crates/listener2-core/src/batch/bootstrap.rs @@ -4,7 +4,7 @@ use std::time::Duration; use owhisper_client::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, CactusAdapter, DashScopeAdapter, DeepgramAdapter, ElevenLabsAdapter, FireworksAdapter, GladiaAdapter, HyprnoteAdapter, - MistralAdapter, OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, + MistralAdapter, OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, WhisperCppAdapter, }; use owhisper_interface::{ControlMessage, MixedMessage}; use ractor::{ActorProcessingErr, ActorRef}; @@ -30,6 +30,16 @@ pub(super) async fn spawn_batch_task( ), ActorProcessingErr, > { + use super::BatchProvider; + + if matches!(args.provider, BatchProvider::WhisperLocal) { + return spawn_whispercpp_batch_task(args, myself).await; + } + + if matches!(args.provider, BatchProvider::Cactus) { + return spawn_cactus_batch_task(args, myself).await; + } + let adapter_kind = AdapterKind::from_url_and_languages( &args.base_url, &args.listen_params.languages, @@ -171,6 +181,57 @@ async fn spawn_cactus_batch_task( Ok((rx_task, shutdown_tx)) } +async fn spawn_whispercpp_batch_task( + args: BatchArgs, + myself: ActorRef, +) -> Result< + ( + tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, + ), + ActorProcessingErr, +> { + let (shutdown_tx, shutdown_rx) 
= tokio::sync::oneshot::channel::<()>(); + + let span = tracing::info_span!( + "whispercpp_batch", + hyprnote.session.id = %args.session_id, + url.full = %args.base_url, + hyprnote.file.path = %args.file_path, + ); + + let rx_task = tokio::spawn( + async move { + let stream = match WhisperCppAdapter::transcribe_file_streaming( + &args.base_url, + &args.listen_params, + &args.file_path, + ) + .await + { + Ok(stream) => { + notify_start_result(&args.start_notifier, Ok(())); + stream + } + Err(err) => { + report_stream_start_failure( + &myself, + &args.start_notifier, + &err, + "whispercpp batch failed to start stream", + ); + return; + } + }; + + process_provider_stream(stream, myself, shutdown_rx, "whispercpp batch").await; + } + .instrument(span), + ); + + Ok((rx_task, shutdown_tx)) +} + async fn spawn_batch_task_with_adapter( args: BatchArgs, myself: ActorRef, diff --git a/crates/listener2-core/src/events.rs b/crates/listener2-core/src/events.rs index 0f5254dd13..333e2ffbba 100644 --- a/crates/listener2-core/src/events.rs +++ b/crates/listener2-core/src/events.rs @@ -1,5 +1,5 @@ use owhisper_interface::batch::Response as BatchResponse; -use owhisper_interface::stream::StreamResponse; +use owhisper_interface::batch_stream::BatchStreamEvent; use crate::BatchRunMode; @@ -38,8 +38,7 @@ pub enum BatchEvent { #[serde(rename = "batchProgress")] BatchResponseStreamed { session_id: String, - response: StreamResponse, - percentage: f64, + event: BatchStreamEvent, }, #[serde(rename = "batchFailed")] BatchFailed { diff --git a/crates/model-manager/src/manager.rs b/crates/model-manager/src/manager.rs index 4443a1a1eb..d35a519fdf 100644 --- a/crates/model-manager/src/manager.rs +++ b/crates/model-manager/src/manager.rs @@ -81,10 +81,10 @@ impl ModelManager { let mut active = self.active.lock().await; - if let Some(ref a) = *active { - if a.name == resolved { - return Ok(Arc::clone(&a.model)); - } + if let Some(ref a) = *active + && a.name == resolved + { + return 
Ok(Arc::clone(&a.model)); } *active = None; @@ -129,10 +129,10 @@ impl ModelManager { _ = shutdown_rx.changed() => break, _ = interval.tick() => { let last = last_activity.lock().await; - if let Some(t) = *last { - if t.elapsed() > inactivity_timeout { - *active.lock().await = None; - } + if let Some(t) = *last + && t.elapsed() > inactivity_timeout + { + *active.lock().await = None; } } } diff --git a/crates/owhisper-client/src/adapter/argmax/batch.rs b/crates/owhisper-client/src/adapter/argmax/batch.rs index 7c2df2524e..9b1745accc 100644 --- a/crates/owhisper-client/src/adapter/argmax/batch.rs +++ b/crates/owhisper-client/src/adapter/argmax/batch.rs @@ -4,6 +4,7 @@ use std::time::Duration; use futures_util::StreamExt; use hypr_audio_utils::{Source, f32_to_i16_bytes, resample_audio, source_from_path}; use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, ListenParams, MixedMessage}; use tokio_stream::StreamExt as TokioStreamExt; @@ -154,7 +155,7 @@ impl StreamingBatchConfig { } } -pub use crate::adapter::{StreamingBatchEvent, StreamingBatchStream}; +pub use crate::adapter::StreamingBatchStream; impl ArgmaxAdapter { pub async fn transcribe_file_streaming>( @@ -216,11 +217,7 @@ impl ArgmaxAdapter { result .map(|response| { let percentage = compute_percentage(&response, audio_duration_secs); - StreamingBatchEvent { - response, - percentage, - final_batch_response: None, - } + to_batch_stream_event(response, percentage) }) .map_err(|e| Error::WebSocket(format!("{:?}", e))) }); @@ -229,6 +226,39 @@ impl ArgmaxAdapter { } } +fn to_batch_stream_event(response: StreamResponse, percentage: f64) -> BatchStreamEvent { + match response { + StreamResponse::TranscriptResponse { .. 
} => BatchStreamEvent::Segment { + response, + percentage, + }, + StreamResponse::TerminalResponse { + request_id, + created, + duration, + channels, + } => BatchStreamEvent::Terminal { + request_id, + created, + duration, + channels, + }, + StreamResponse::ErrorResponse { + error_code, + error_message, + provider, + } => BatchStreamEvent::Error { + error_code, + error_message, + provider, + }, + other => BatchStreamEvent::Segment { + response: other, + percentage, + }, + } +} + fn compute_percentage(response: &StreamResponse, audio_duration_secs: f64) -> f64 { let transcript_end = transcript_end_from_response(response); match transcript_end { diff --git a/crates/owhisper-client/src/adapter/assemblyai/batch.rs b/crates/owhisper-client/src/adapter/assemblyai/batch.rs index 943f186f7c..20335f8a62 100644 --- a/crates/owhisper-client/src/adapter/assemblyai/batch.rs +++ b/crates/owhisper-client/src/adapter/assemblyai/batch.rs @@ -252,12 +252,20 @@ impl AssemblyAIAdapter { .parse::() .ok() }); + let channel = w + .channel + .as_deref() + .and_then(|s| s.parse::().ok()) + .map(|channel| channel.saturating_sub(1)) + .unwrap_or(0) + .max(0); BatchWord { word: w.text.clone(), start: w.start as f64 / 1000.0, end: w.end as f64 / 1000.0, confidence: w.confidence, + channel, speaker, punctuated_word: Some(w.text), } @@ -324,6 +332,50 @@ mod tests { use super::*; use crate::http_client::create_client; + #[test] + fn multichannel_words_are_normalized_to_zero_based_channels() { + let response = TranscriptResponse { + id: "id".to_string(), + status: "completed".to_string(), + text: None, + words: Some(vec![ + AssemblyAIBatchWord { + text: "left".to_string(), + start: 0, + end: 500, + confidence: 0.9, + speaker: Some("1A".to_string()), + channel: Some("1".to_string()), + }, + AssemblyAIBatchWord { + text: "right".to_string(), + start: 500, + end: 1000, + confidence: 0.8, + speaker: Some("2A".to_string()), + channel: Some("2".to_string()), + }, + ]), + utterances: None, + confidence: 
Some(0.85), + audio_duration: Some(1), + audio_channels: Some(2), + error: None, + }; + + let result = AssemblyAIAdapter::convert_to_batch_response(response); + + assert_eq!(result.results.channels.len(), 2); + assert_eq!( + result.results.channels[0].alternatives[0].words[0].channel, + 0 + ); + assert_eq!( + result.results.channels[1].alternatives[0].words[0].channel, + 1 + ); + } + #[tokio::test] #[ignore] async fn test_assemblyai_batch_transcription() { diff --git a/crates/owhisper-client/src/adapter/cactus/batch.rs b/crates/owhisper-client/src/adapter/cactus/batch.rs index c4e9f0741b..9efb004103 100644 --- a/crates/owhisper-client/src/adapter/cactus/batch.rs +++ b/crates/owhisper-client/src/adapter/cactus/batch.rs @@ -3,6 +3,7 @@ use std::path::{Path, PathBuf}; use futures_util::StreamExt; use owhisper_interface::ListenParams; use owhisper_interface::batch_sse::{BatchSseMessage, EVENT_NAME as BATCH_EVENT}; +use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::progress::InferenceProgress; use owhisper_interface::stream::StreamResponse; @@ -264,28 +265,9 @@ impl SseParserState { ) -> Option> { self.last_percentage = self.last_percentage.max(progress.percentage); - let response = StreamResponse::TranscriptResponse { - start: 0.0, - duration: self.audio_duration_secs * progress.percentage, - is_final: false, - speech_final: false, - from_finalize: false, - channel: owhisper_interface::stream::Channel { - alternatives: vec![owhisper_interface::stream::Alternatives { - transcript: progress.partial_text.clone().unwrap_or_default(), - languages: vec![], - words: vec![], - confidence: 0.0, - }], - }, - metadata: owhisper_interface::stream::Metadata::default(), - channel_index: vec![0, 1], - }; - - Some(Ok(StreamingBatchEvent { - response, + Some(Ok(BatchStreamEvent::Progress { percentage: progress.percentage, - final_batch_response: None, + partial_text: progress.partial_text, })) } @@ -307,53 +289,46 @@ impl SseParserState { }; 
self.last_percentage = self.last_percentage.max(percentage); - Some(Ok(StreamingBatchEvent { - response, - percentage: self.last_percentage, - final_batch_response: None, - })) + let event = match response { + StreamResponse::TranscriptResponse { .. } => BatchStreamEvent::Segment { + response, + percentage: self.last_percentage, + }, + StreamResponse::TerminalResponse { + request_id, + created, + duration, + channels, + } => BatchStreamEvent::Terminal { + request_id, + created, + duration, + channels, + }, + StreamResponse::ErrorResponse { + error_code, + error_message, + provider, + } => BatchStreamEvent::Error { + error_code, + error_message, + provider, + }, + other => BatchStreamEvent::Segment { + response: other, + percentage: self.last_percentage, + }, + }; + + Some(Ok(event)) } fn handle_result( &mut self, batch_response: owhisper_interface::batch::Response, ) -> Option> { - let transcript = batch_response - .results - .channels - .first() - .and_then(|c| c.alternatives.first()) - .map(|a| a.transcript.clone()) - .unwrap_or_default(); - - let duration = batch_response - .metadata - .get("duration") - .and_then(|v| v.as_f64()) - .unwrap_or(self.audio_duration_secs); - - let response = StreamResponse::TranscriptResponse { - start: 0.0, - duration, - is_final: true, - speech_final: true, - from_finalize: true, - channel: owhisper_interface::stream::Channel { - alternatives: vec![owhisper_interface::stream::Alternatives { - transcript, - languages: vec![], - words: vec![], - confidence: 0.0, - }], - }, - metadata: owhisper_interface::stream::Metadata::default(), - channel_index: vec![0, 1], - }; - - Some(Ok(StreamingBatchEvent { - response, - percentage: 1.0, - final_batch_response: Some(batch_response), + Some(Ok(BatchStreamEvent::Result { + response: batch_response, })) } } diff --git a/crates/owhisper-client/src/adapter/deepgram/batch.rs b/crates/owhisper-client/src/adapter/deepgram/batch.rs index 7b0c59893a..e324e691c1 100644 --- 
a/crates/owhisper-client/src/adapter/deepgram/batch.rs +++ b/crates/owhisper-client/src/adapter/deepgram/batch.rs @@ -1,7 +1,11 @@ use std::path::{Path, PathBuf}; use owhisper_interface::ListenParams; -use owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::batch::{ + Alternatives as BatchAlternatives, Channel as BatchChannel, Response as BatchResponse, + Results as BatchResults, Word as BatchWord, +}; +use serde::Deserialize; use crate::adapter::deepgram_compat::build_batch_url; use crate::adapter::{BatchFuture, BatchSttAdapter, ClientWithMiddleware}; @@ -70,7 +74,8 @@ async fn do_transcribe_file( let status = response.status(); if status.is_success() { - Ok(response.json().await?) + let legacy: DeepgramBatchResponse = response.json().await?; + Ok(convert_response(legacy)) } else { Err(Error::UnexpectedStatus { status, @@ -79,10 +84,180 @@ async fn do_transcribe_file( } } +#[derive(Debug, Deserialize)] +struct DeepgramBatchResponse { + metadata: serde_json::Value, + results: DeepgramBatchResults, +} + +#[derive(Debug, Deserialize)] +struct DeepgramBatchResults { + channels: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeepgramBatchChannel { + alternatives: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeepgramBatchAlternatives { + transcript: String, + confidence: f64, + #[serde(default)] + words: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeepgramBatchWord { + word: String, + start: f64, + end: f64, + confidence: f64, + #[serde(default)] + speaker: Option, + #[serde(default)] + punctuated_word: Option, +} + +fn convert_response(response: DeepgramBatchResponse) -> BatchResponse { + let channels = response + .results + .channels + .into_iter() + .enumerate() + .map(|(channel_idx, channel)| BatchChannel { + alternatives: channel + .alternatives + .into_iter() + .map(|alt| BatchAlternatives { + transcript: alt.transcript, + confidence: alt.confidence, + words: alt + .words + .into_iter() + .map(|word| BatchWord { + word: 
word.word, + start: word.start, + end: word.end, + confidence: word.confidence, + channel: channel_idx as i32, + speaker: word.speaker, + punctuated_word: word.punctuated_word, + }) + .collect(), + }) + .collect(), + }) + .collect(); + + BatchResponse { + metadata: response.metadata, + results: BatchResults { channels }, + } +} + #[cfg(test)] mod tests { use super::*; + use crate::adapter::deepgram_compat::{ + KeywordQueryStrategy, LanguageQueryStrategy, TranscriptionMode, + }; use crate::http_client::create_client; + use url::UrlQuery; + use url::form_urlencoded::Serializer; + + struct NoLanguageStrategy; + + impl LanguageQueryStrategy for NoLanguageStrategy { + fn append_language_query<'a>( + &self, + _query_pairs: &mut Serializer<'a, UrlQuery>, + _params: &ListenParams, + _mode: TranscriptionMode, + ) { + } + } + + struct NoKeywordStrategy; + + impl KeywordQueryStrategy for NoKeywordStrategy { + fn append_keyword_query<'a>( + &self, + _query_pairs: &mut Serializer<'a, UrlQuery>, + _params: &ListenParams, + ) { + } + } + + #[test] + fn preserves_channel_identity_for_multichannel_batch_words() { + let response = DeepgramBatchResponse { + metadata: serde_json::json!({ "channels": 2 }), + results: DeepgramBatchResults { + channels: vec![ + DeepgramBatchChannel { + alternatives: vec![DeepgramBatchAlternatives { + transcript: "left".to_string(), + confidence: 0.9, + words: vec![DeepgramBatchWord { + word: "left".to_string(), + start: 0.0, + end: 1.0, + confidence: 0.9, + speaker: None, + punctuated_word: Some("left".to_string()), + }], + }], + }, + DeepgramBatchChannel { + alternatives: vec![DeepgramBatchAlternatives { + transcript: "right".to_string(), + confidence: 0.8, + words: vec![DeepgramBatchWord { + word: "right".to_string(), + start: 0.0, + end: 1.0, + confidence: 0.8, + speaker: None, + punctuated_word: Some("right".to_string()), + }], + }], + }, + ], + }, + }; + + let converted = convert_response(response); + + assert_eq!(converted.results.channels.len(), 
2); + assert_eq!( + converted.results.channels[0].alternatives[0].words[0].channel, + 0 + ); + assert_eq!( + converted.results.channels[1].alternatives[0].words[0].channel, + 1 + ); + } + + #[test] + fn batch_url_enables_multichannel_for_stereo_audio() { + let params = ListenParams { + channels: 2, + ..Default::default() + }; + + let url = build_batch_url( + "https://api.deepgram.com/v1", + ¶ms, + &NoLanguageStrategy, + &NoKeywordStrategy, + ); + + let query = url.query().unwrap_or_default(); + assert!(query.contains("multichannel=true")); + } #[tokio::test] #[ignore] diff --git a/crates/owhisper-client/src/adapter/deepgram_compat/mod.rs b/crates/owhisper-client/src/adapter/deepgram_compat/mod.rs index bd377889ec..838e0d7eab 100644 --- a/crates/owhisper-client/src/adapter/deepgram_compat/mod.rs +++ b/crates/owhisper-client/src/adapter/deepgram_compat/mod.rs @@ -139,7 +139,7 @@ where builder .add("model", model) .add_bool("diarize", true) - .add_bool("multichannel", false) + .add_bool("multichannel", params.channels > 1) .add_bool("punctuate", true) .add_bool("smart_format", true) .add_bool("utterances", true) diff --git a/crates/owhisper-client/src/adapter/elevenlabs/batch.rs b/crates/owhisper-client/src/adapter/elevenlabs/batch.rs index e2f619d438..2e43e231d7 100644 --- a/crates/owhisper-client/src/adapter/elevenlabs/batch.rs +++ b/crates/owhisper-client/src/adapter/elevenlabs/batch.rs @@ -131,6 +131,7 @@ impl ElevenLabsAdapter { start: w.start, end: w.end, confidence: 1.0, + channel: 0, speaker, punctuated_word: Some(w.text.clone()), } diff --git a/crates/owhisper-client/src/adapter/fireworks/batch.rs b/crates/owhisper-client/src/adapter/fireworks/batch.rs index c6281bc149..af9c8a21fe 100644 --- a/crates/owhisper-client/src/adapter/fireworks/batch.rs +++ b/crates/owhisper-client/src/adapter/fireworks/batch.rs @@ -109,6 +109,7 @@ impl FireworksAdapter { start: w.start, end: w.end, confidence: 1.0, + channel: 0, speaker: None, punctuated_word: Some(w.word), }) diff 
--git a/crates/owhisper-client/src/adapter/gladia/batch.rs b/crates/owhisper-client/src/adapter/gladia/batch.rs index f1eb8245f6..fa739ecf27 100644 --- a/crates/owhisper-client/src/adapter/gladia/batch.rs +++ b/crates/owhisper-client/src/adapter/gladia/batch.rs @@ -312,6 +312,7 @@ impl GladiaAdapter { start: w.start, end: w.end, confidence: w.confidence, + channel: u.channel as i32, speaker: u.speaker, punctuated_word: Some(trimmed), } diff --git a/crates/owhisper-client/src/adapter/mistral/batch.rs b/crates/owhisper-client/src/adapter/mistral/batch.rs index c037ad6288..ef61962d80 100644 --- a/crates/owhisper-client/src/adapter/mistral/batch.rs +++ b/crates/owhisper-client/src/adapter/mistral/batch.rs @@ -165,6 +165,7 @@ fn convert_response(response: MistralBatchResponse) -> BatchResponse { start: w.start, end: w.end, confidence: 1.0, + channel: 0, speaker: None, punctuated_word: Some(w.word), } @@ -199,6 +200,7 @@ fn convert_response(response: MistralBatchResponse) -> BatchResponse { start: word_start, end: word_end, confidence: 1.0, + channel: 0, speaker: None, punctuated_word: Some(w.to_string()), } diff --git a/crates/owhisper-client/src/adapter/mod.rs b/crates/owhisper-client/src/adapter/mod.rs index f179ef097f..8043fd20f2 100644 --- a/crates/owhisper-client/src/adapter/mod.rs +++ b/crates/owhisper-client/src/adapter/mod.rs @@ -1,3 +1,6 @@ +pub mod parsing; +mod url_builder; + mod argmax; pub(crate) mod assemblyai; mod cactus; @@ -13,9 +16,8 @@ mod language; mod mistral; mod openai; mod owhisper; -pub mod parsing; pub(crate) mod soniox; -mod url_builder; +mod whispercpp; pub use argmax::*; pub use assemblyai::*; @@ -30,6 +32,7 @@ pub use language::{LanguageQuality, LanguageSupport}; pub use mistral::*; pub use openai::*; pub use soniox::*; +pub use whispercpp::*; use std::collections::{BTreeSet, HashSet}; use std::future::Future; @@ -39,6 +42,7 @@ use std::pin::Pin; use hypr_ws_client::client::Message; use owhisper_interface::ListenParams; use 
owhisper_interface::batch::Response as BatchResponse; +use owhisper_interface::batch_stream::BatchStreamEvent; use owhisper_interface::stream::StreamResponse; use crate::error::Error; @@ -47,15 +51,10 @@ pub use reqwest_middleware::ClientWithMiddleware; pub type BatchFuture<'a> = Pin> + Send + 'a>>; -#[derive(Debug, Clone)] -pub struct StreamingBatchEvent { - pub response: StreamResponse, - pub percentage: f64, - pub final_batch_response: Option, -} +pub type StreamingBatchEvent = BatchStreamEvent; pub type StreamingBatchStream = - Pin> + Send>>; + Pin> + Send>>; pub fn documented_language_codes_live() -> Vec { let mut set: BTreeSet<&'static str> = BTreeSet::new(); diff --git a/crates/owhisper-client/src/adapter/openai/batch.rs b/crates/owhisper-client/src/adapter/openai/batch.rs index d6e1ba8dfb..0ab2dde41e 100644 --- a/crates/owhisper-client/src/adapter/openai/batch.rs +++ b/crates/owhisper-client/src/adapter/openai/batch.rs @@ -173,6 +173,7 @@ fn convert_response(response: OpenAIVerboseResponse) -> BatchResponse { start: w.start, end: w.end, confidence: 1.0, + channel: 0, speaker: None, punctuated_word: Some(w.word), } diff --git a/crates/owhisper-client/src/adapter/soniox/batch.rs b/crates/owhisper-client/src/adapter/soniox/batch.rs index 175e44abb2..bef7e0fa04 100644 --- a/crates/owhisper-client/src/adapter/soniox/batch.rs +++ b/crates/owhisper-client/src/adapter/soniox/batch.rs @@ -111,6 +111,7 @@ impl SonioxAdapter { start: token.start_ms.unwrap_or(0) as f64 / 1000.0, end: token.end_ms.unwrap_or(0) as f64 / 1000.0, confidence: token.confidence.unwrap_or(1.0), + channel: 0, speaker: token.speaker.as_ref().and_then(|s| s.as_usize()), punctuated_word: Some(token.text.clone()), }) diff --git a/crates/owhisper-client/src/adapter/soniox/live.rs b/crates/owhisper-client/src/adapter/soniox/live.rs index 8f448c6396..58841a52f1 100644 --- a/crates/owhisper-client/src/adapter/soniox/live.rs +++ b/crates/owhisper-client/src/adapter/soniox/live.rs @@ -298,9 +298,9 @@ fn 
build_words(tokens: &[&soniox::Token]) -> Vec( - tokens: &'a [soniox::Token], -) -> (Vec<&'a soniox::Token>, Vec<&'a soniox::Token>) { +fn partition_tokens_by_word_finality( + tokens: &[soniox::Token], +) -> (Vec<&soniox::Token>, Vec<&soniox::Token>) { let mut final_tokens = Vec::new(); let mut non_final_tokens = Vec::new(); for group in token_groups_from_values(tokens) { diff --git a/crates/owhisper-client/src/adapter/whispercpp/batch.rs b/crates/owhisper-client/src/adapter/whispercpp/batch.rs new file mode 100644 index 0000000000..5cd4c6b608 --- /dev/null +++ b/crates/owhisper-client/src/adapter/whispercpp/batch.rs @@ -0,0 +1,344 @@ +use std::path::{Path, PathBuf}; + +use futures_util::StreamExt; +use owhisper_interface::ListenParams; +use owhisper_interface::batch_sse::{BatchSseMessage, EVENT_NAME as BATCH_EVENT}; +use owhisper_interface::batch_stream::BatchStreamEvent; +use owhisper_interface::progress::InferenceProgress; +use owhisper_interface::stream::StreamResponse; + +use crate::adapter::{StreamingBatchEvent, StreamingBatchStream}; +use crate::error::Error; + +use super::WhisperCppAdapter; + +impl WhisperCppAdapter { + pub async fn transcribe_file_streaming( + api_base: &str, + params: &ListenParams, + file_path: impl AsRef, + ) -> Result { + let path = file_path.as_ref().to_path_buf(); + tracing::info!( + hyprnote.file.path = %path.display(), + url.full = %api_base, + "starting_whispercpp_batch_stream" + ); + + let (audio_data, content_type, audio_duration_secs) = + tokio::task::spawn_blocking(move || load_audio_file(path)) + .await + .map_err(|e| Error::AudioProcessing(format!("task panicked: {:?}", e)))??; + + let url = build_batch_url(api_base, params); + + let response = reqwest::Client::new() + .post(url.as_str()) + .header("Content-Type", &content_type) + .header("Accept", "text/event-stream") + .body(audio_data) + .send() + .await?; + + let status = response.status(); + if !status.is_success() { + let body = 
response.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { status, body }); + } + + let byte_stream = response.bytes_stream(); + + let event_stream = futures_util::stream::unfold( + SseParserState::new(byte_stream, audio_duration_secs), + |mut state| async move { + loop { + if let Some(event) = state.pending_events.pop_front() { + return Some((event, state)); + } + + match state.stream.next().await { + Some(Ok(chunk)) => { + state.buffer.extend_from_slice(&chunk); + state.parse_buffer(); + } + Some(Err(e)) => { + return Some(( + Err(Error::WebSocket(format!("stream error: {:?}", e))), + state, + )); + } + None => { + if !state.buffer.is_empty() { + state.parse_buffer(); + if let Some(event) = state.pending_events.pop_front() { + return Some((event, state)); + } + } + return None; + } + } + } + }, + ); + + Ok(Box::pin(event_stream)) + } +} + +fn load_audio_file(path: PathBuf) -> Result<(Vec, String, f64), Error> { + let data = + std::fs::read(&path).map_err(|e| Error::AudioProcessing(format!("read failed: {e}")))?; + + let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("wav"); + let content_type = match extension { + "wav" => "audio/wav", + "mp3" => "audio/mpeg", + "ogg" => "audio/ogg", + "flac" => "audio/flac", + "m4a" => "audio/mp4", + "webm" => "audio/webm", + _ => "application/octet-stream", + } + .to_string(); + + let duration = audio_duration_secs(&path); + + Ok((data, content_type, duration)) +} + +fn audio_duration_secs(path: &Path) -> f64 { + use hypr_audio_utils::Source; + let Ok(source) = hypr_audio_utils::source_from_path(path) else { + return 0.0; + }; + if let Some(d) = source.total_duration() { + return d.as_secs_f64(); + } + let sample_rate = u32::from(source.sample_rate()) as f64; + let channels = u16::from(source.channels()).max(1) as f64; + let count = source.count() as f64; + count / channels / sample_rate +} + +fn build_batch_url(api_base: &str, params: &ListenParams) -> url::Url { + let mut url: url::Url 
= api_base.parse().expect("invalid api_base URL"); + + if !url.path().ends_with("/listen") { + let path = url.path().trim_end_matches('/').to_string(); + url.set_path(&format!("{}/listen", path)); + } + + for lang in &params.languages { + url.query_pairs_mut() + .append_pair("language", lang.iso639().code()); + } + if !params.keywords.is_empty() { + for kw in &params.keywords { + url.query_pairs_mut().append_pair("keywords", kw); + } + } + if let Some(ref model) = params.model { + url.query_pairs_mut().append_pair("model", model); + } + + url +} + +struct SseParserState { + stream: S, + buffer: Vec, + pending_events: std::collections::VecDeque>, + audio_duration_secs: f64, + last_percentage: f64, +} + +impl SseParserState { + fn new(stream: S, audio_duration_secs: f64) -> Self { + Self { + stream, + buffer: Vec::new(), + pending_events: std::collections::VecDeque::new(), + audio_duration_secs, + last_percentage: 0.0, + } + } + + fn parse_buffer(&mut self) { + while let Ok(text) = std::str::from_utf8(&self.buffer) { + let Some(end) = text.find("\n\n") else { + break; + }; + + let block = text[..end].to_string(); + self.buffer.drain(..end + 2); + + if let Some(event) = self.parse_sse_block(&block) { + self.pending_events.push_back(event); + } + } + } + + fn parse_sse_block(&mut self, block: &str) -> Option> { + let mut event_type = String::new(); + let mut data = String::new(); + + for line in block.lines() { + if let Some(rest) = line.strip_prefix("event:") { + event_type = rest.trim().to_string(); + } else if let Some(rest) = line.strip_prefix("data:") { + if !data.is_empty() { + data.push('\n'); + } + data.push_str(rest.trim()); + } else if line.starts_with(':') { + // comment, skip + } + } + + if data.is_empty() { + return None; + } + + match event_type.as_str() { + BATCH_EVENT => { + let msg: BatchSseMessage = match serde_json::from_str(&data) { + Ok(m) => m, + Err(e) => { + tracing::warn!( + raw_data = %data, + "failed to parse batch SSE event: {e}" + ); + return None; 
+ } + }; + + match msg { + BatchSseMessage::Progress { progress } => self.handle_progress(progress), + BatchSseMessage::Segment { response } => self.handle_segment(response), + BatchSseMessage::Result { response } => self.handle_result(response), + BatchSseMessage::Error { detail, .. } => { + tracing::error!(detail = %detail, "server returned error event"); + Some(Err(Error::WebSocket(format!("server error: {}", detail)))) + } + } + } + "progress" => { + let progress: InferenceProgress = match serde_json::from_str(&data) { + Ok(p) => p, + Err(e) => { + tracing::warn!(raw_data = %data, "failed to parse progress event: {e}"); + return None; + } + }; + self.handle_progress(progress) + } + "segment" => { + let response: StreamResponse = match serde_json::from_str(&data) { + Ok(r) => r, + Err(e) => { + tracing::warn!(raw_data = %data, "failed to parse segment event: {e}"); + return None; + } + }; + self.handle_segment(response) + } + "result" => { + let batch_response: owhisper_interface::batch::Response = + match serde_json::from_str(&data) { + Ok(r) => r, + Err(e) => { + tracing::error!(raw_data = %data, "failed to parse result event: {e}"); + return Some(Err(Error::WebSocket(format!( + "failed to parse result: {e}" + )))); + } + }; + + self.handle_result(batch_response) + } + "error" => { + let error_data: serde_json::Value = serde_json::from_str(&data).unwrap_or_default(); + let detail = error_data + .get("detail") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + tracing::error!(detail = %detail, raw_data = %data, "server returned error event"); + Some(Err(Error::WebSocket(format!("server error: {}", detail)))) + } + _ => None, + } + } + + fn handle_progress( + &mut self, + progress: InferenceProgress, + ) -> Option> { + self.last_percentage = self.last_percentage.max(progress.percentage); + + Some(Ok(BatchStreamEvent::Progress { + percentage: progress.percentage, + partial_text: progress.partial_text, + })) + } + + fn handle_segment( + &mut self, + 
response: StreamResponse, + ) -> Option> { + let segment_end = match &response { + StreamResponse::TranscriptResponse { + start, duration, .. + } => start + duration, + _ => 0.0, + }; + + let percentage = if self.audio_duration_secs > 0.0 { + (segment_end / self.audio_duration_secs).clamp(0.0, 1.0) + } else { + 0.0 + }; + self.last_percentage = self.last_percentage.max(percentage); + + let event = match response { + StreamResponse::TranscriptResponse { .. } => BatchStreamEvent::Segment { + response, + percentage: self.last_percentage, + }, + StreamResponse::TerminalResponse { + request_id, + created, + duration, + channels, + } => BatchStreamEvent::Terminal { + request_id, + created, + duration, + channels, + }, + StreamResponse::ErrorResponse { + error_code, + error_message, + provider, + } => BatchStreamEvent::Error { + error_code, + error_message, + provider, + }, + other => BatchStreamEvent::Segment { + response: other, + percentage: self.last_percentage, + }, + }; + + Some(Ok(event)) + } + + fn handle_result( + &mut self, + batch_response: owhisper_interface::batch::Response, + ) -> Option> { + Some(Ok(BatchStreamEvent::Result { + response: batch_response, + })) + } +} diff --git a/crates/owhisper-client/src/adapter/whispercpp/mod.rs b/crates/owhisper-client/src/adapter/whispercpp/mod.rs new file mode 100644 index 0000000000..c1ffc64554 --- /dev/null +++ b/crates/owhisper-client/src/adapter/whispercpp/mod.rs @@ -0,0 +1,5 @@ +#[cfg(feature = "local")] +mod batch; + +#[derive(Clone, Default)] +pub struct WhisperCppAdapter; diff --git a/crates/owhisper-client/src/lib.rs b/crates/owhisper-client/src/lib.rs index 320d7f2a78..c76965cc8f 100644 --- a/crates/owhisper-client/src/lib.rs +++ b/crates/owhisper-client/src/lib.rs @@ -21,7 +21,7 @@ pub use adapter::{ AdapterKind, ArgmaxAdapter, AssemblyAIAdapter, BatchSttAdapter, CactusAdapter, CallbackResult, CallbackSttAdapter, DashScopeAdapter, DeepgramAdapter, ElevenLabsAdapter, FireworksAdapter, GladiaAdapter, 
HyprnoteAdapter, LanguageQuality, LanguageSupport, MistralAdapter, - OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, append_provider_param, + OpenAIAdapter, RealtimeSttAdapter, SonioxAdapter, WhisperCppAdapter, append_provider_param, documented_language_codes_batch, documented_language_codes_live, is_hyprnote_proxy, is_local_host, normalize_languages, }; diff --git a/crates/owhisper-interface/src/batch.rs b/crates/owhisper-interface/src/batch.rs index 4e9b080fac..6c455a871b 100644 --- a/crates/owhisper-interface/src/batch.rs +++ b/crates/owhisper-interface/src/batch.rs @@ -12,6 +12,8 @@ common_derives! { pub start: f64, pub end: f64, pub confidence: f64, + #[serde(default)] + pub channel: i32, pub speaker: Option, pub punctuated_word: Option, } @@ -61,6 +63,7 @@ impl From for Word { start: word.start, end: word.end, confidence: word.confidence, + channel: 0, speaker: word .speaker .and_then(|speaker| (speaker >= 0).then_some(speaker as usize)), @@ -103,3 +106,40 @@ impl From for serde_json::Value { serde_json::to_value(metadata).unwrap_or_else(|_| serde_json::json!({})) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn deserialize_batch_word_defaults_missing_channel_to_zero() { + let response: Response = serde_json::from_value(serde_json::json!({ + "metadata": {}, + "results": { + "channels": [ + { + "alternatives": [ + { + "transcript": "hello", + "confidence": 0.9, + "words": [ + { + "word": "hello", + "start": 0.0, + "end": 1.0, + "confidence": 0.9 + } + ] + } + ] + } + ] + } + })) + .unwrap(); + + let word = &response.results.channels[0].alternatives[0].words[0]; + assert_eq!(word.channel, 0); + assert_eq!(word.word, "hello"); + } +} diff --git a/crates/owhisper-interface/src/batch_sse.rs b/crates/owhisper-interface/src/batch_sse.rs index 2518f5fa43..3b7cb27fdc 100644 --- a/crates/owhisper-interface/src/batch_sse.rs +++ b/crates/owhisper-interface/src/batch_sse.rs @@ -1,4 +1,4 @@ -use crate::{InferenceProgress, batch, common_derives, stream}; 
+use crate::{InferenceProgress, batch, batch_stream, common_derives, stream}; pub const EVENT_NAME: &str = "batch"; @@ -11,3 +11,26 @@ common_derives! { Error { error: String, detail: String }, } } + +impl From for batch_stream::BatchStreamEvent { + fn from(value: BatchSseMessage) -> Self { + match value { + BatchSseMessage::Progress { progress } => batch_stream::BatchStreamEvent::Progress { + percentage: progress.percentage, + partial_text: progress.partial_text, + }, + BatchSseMessage::Segment { response } => batch_stream::BatchStreamEvent::Segment { + percentage: 0.0, + response, + }, + BatchSseMessage::Result { response } => { + batch_stream::BatchStreamEvent::Result { response } + } + BatchSseMessage::Error { error, detail } => batch_stream::BatchStreamEvent::Error { + error_code: None, + error_message: detail, + provider: error, + }, + } + } +} diff --git a/crates/owhisper-interface/src/batch_stream.rs b/crates/owhisper-interface/src/batch_stream.rs new file mode 100644 index 0000000000..9368a8c404 --- /dev/null +++ b/crates/owhisper-interface/src/batch_stream.rs @@ -0,0 +1,54 @@ +use crate::{batch, common_derives, stream}; + +common_derives! { + #[serde(tag = "type", rename_all = "snake_case")] + pub enum BatchStreamEvent { + Progress { + percentage: f64, + #[serde(default)] + partial_text: Option, + }, + Segment { + response: stream::StreamResponse, + percentage: f64, + }, + Terminal { + request_id: String, + created: String, + duration: f64, + channels: u32, + }, + Result { + response: batch::Response, + }, + Error { + error_code: Option, + error_message: String, + provider: String, + }, + } +} + +impl BatchStreamEvent { + pub fn percentage(&self) -> f64 { + match self { + Self::Progress { percentage, .. } | Self::Segment { percentage, .. } => *percentage, + Self::Terminal { .. } | Self::Result { .. } => 1.0, + Self::Error { .. } => 0.0, + } + } + + pub fn text(&self) -> Option<&str> { + match self { + Self::Progress { partial_text, .. 
} => partial_text.as_deref(), + Self::Segment { response, .. } => response.text(), + Self::Result { response } => response + .results + .channels + .first() + .and_then(|channel| channel.alternatives.first()) + .map(|alternative| alternative.transcript.as_str()), + Self::Terminal { .. } | Self::Error { .. } => None, + } + } +} diff --git a/crates/owhisper-interface/src/lib.rs b/crates/owhisper-interface/src/lib.rs index 6daa4f8eeb..dda3d58462 100644 --- a/crates/owhisper-interface/src/lib.rs +++ b/crates/owhisper-interface/src/lib.rs @@ -1,5 +1,6 @@ pub mod batch; pub mod batch_sse; +pub mod batch_stream; #[cfg(feature = "openapi")] pub mod openapi; pub mod progress; diff --git a/crates/owhisper-interface/src/openapi.rs b/crates/owhisper-interface/src/openapi.rs index aacb974d48..b460f206d4 100644 --- a/crates/owhisper-interface/src/openapi.rs +++ b/crates/owhisper-interface/src/openapi.rs @@ -39,6 +39,7 @@ pub struct StreamListenParams { crate::batch::Channel, crate::batch::Alternatives, crate::batch::Word, + crate::batch_stream::BatchStreamEvent, crate::stream::StreamResponse, crate::stream::Channel, crate::stream::Alternatives, diff --git a/crates/transcribe-cactus/src/service/batch/response.rs b/crates/transcribe-cactus/src/service/batch/response.rs index 733904586b..a79b1e54da 100644 --- a/crates/transcribe-cactus/src/service/batch/response.rs +++ b/crates/transcribe-cactus/src/service/batch/response.rs @@ -4,6 +4,7 @@ pub(super) fn build_batch_words( transcript: &str, total_duration: f64, confidence: f64, + channel: i32, ) -> Vec { let word_strs: Vec<&str> = transcript.split_whitespace().collect(); if word_strs.is_empty() || total_duration <= 0.0 { @@ -19,6 +20,7 @@ pub(super) fn build_batch_words( start: i as f64 * word_duration, end: (i + 1) as f64 * word_duration, confidence, + channel, speaker: None, punctuated_word: Some(w.to_string()), }) @@ -93,7 +95,7 @@ mod tests { #[test] fn batch_words_evenly_distributed() { - let words = build_batch_words("hello 
beautiful world", 3.0, 0.9); + let words = build_batch_words("hello beautiful world", 3.0, 0.9, 0); assert_eq!(words.len(), 3); assert_eq!(words[0].word, "hello"); @@ -117,19 +119,19 @@ mod tests { #[test] fn batch_words_empty_transcript() { - let words = build_batch_words("", 5.0, 0.9); + let words = build_batch_words("", 5.0, 0.9, 0); assert!(words.is_empty()); } #[test] fn batch_words_zero_duration() { - let words = build_batch_words("hello world", 0.0, 0.9); + let words = build_batch_words("hello world", 0.0, 0.9, 0); assert!(words.is_empty()); } #[test] fn batch_response_deepgram_shape() { - let words = build_batch_words("hello world", 2.0, 0.95); + let words = build_batch_words("hello world", 2.0, 0.95, 0); let meta = Metadata { model_info: ModelInfo { name: "test".to_string(), @@ -205,14 +207,14 @@ mod tests { alternatives: vec![batch::Alternatives { transcript: "left".to_string(), confidence: 0.9, - words: build_batch_words("left", 1.0, 0.9), + words: build_batch_words("left", 1.0, 0.9, 0), }], }, batch::Channel { alternatives: vec![batch::Alternatives { transcript: "right".to_string(), confidence: 0.8, - words: build_batch_words("right", 1.0, 0.8), + words: build_batch_words("right", 1.0, 0.8, 1), }], }, ], diff --git a/crates/transcribe-cactus/src/service/batch/transcribe.rs b/crates/transcribe-cactus/src/service/batch/transcribe.rs index b3613f5812..4fe5678215 100644 --- a/crates/transcribe-cactus/src/service/batch/transcribe.rs +++ b/crates/transcribe-cactus/src/service/batch/transcribe.rs @@ -80,7 +80,7 @@ pub(super) fn transcribe_batch( channel_idx, chunks, channel_duration, - &model, + model, &options, &mut progress, &metadata, @@ -114,6 +114,7 @@ pub(super) fn transcribe_batch( }) } +#[allow(clippy::too_many_arguments)] fn transcribe_chunks( channel_idx: usize, chunks: &[hypr_vad_chunking::AudioChunk], @@ -169,6 +170,7 @@ fn transcribe_chunks( &chunk_text, chunk_duration_sec, cactus_response.confidence as f64, + channel_idx as i32, ); for w in &mut 
words { w.start += chunk_start_sec; diff --git a/crates/transcribe-whisper-local/src/service/batch.rs b/crates/transcribe-whisper-local/src/service/batch.rs index 1c6c9c5f22..0b3d4158fc 100644 --- a/crates/transcribe-whisper-local/src/service/batch.rs +++ b/crates/transcribe-whisper-local/src/service/batch.rs @@ -283,7 +283,7 @@ fn transcribe_chunks( for segment in segments { cumulative_confidence += segment.confidence; segment_count += 1; - all_words.extend(build_batch_words(&segment)); + all_words.extend(build_batch_words(&segment, channel_idx as i32)); if let Some(tx) = progress.event_tx() { let _ = tx.send(BatchSseMessage::Segment { diff --git a/crates/transcribe-whisper-local/src/service/mod.rs b/crates/transcribe-whisper-local/src/service/mod.rs index ffe44d7406..4bffb4a5a0 100644 --- a/crates/transcribe-whisper-local/src/service/mod.rs +++ b/crates/transcribe-whisper-local/src/service/mod.rs @@ -92,39 +92,164 @@ pub(crate) fn transcribe_chunk( samples: &[f32], chunk_start_sec: f64, ) -> Result, crate::Error> { - Ok(model - .transcribe(samples)? 
+ let raw_segments = model.transcribe(samples)?; + let chunk_duration_sec = samples.len() as f64 / TARGET_SAMPLE_RATE as f64; + + Ok(build_chunk_segments( + raw_segments, + chunk_start_sec, + chunk_duration_sec, + )) +} + +fn build_chunk_segments( + raw_segments: Vec, + chunk_start_sec: f64, + chunk_duration_sec: f64, +) -> Vec { + if chunk_duration_sec <= 0.0 { + return vec![]; + } + + let raw_segments: Vec<_> = raw_segments + .into_iter() + .filter_map(|segment| { + let text = segment.text().trim().to_string(); + if text.is_empty() { + return None; + } + + Some(( + segment.start(), + segment.end(), + Segment { + text, + start: 0.0, + duration: 0.0, + confidence: segment.confidence() as f64, + language: segment.language().map(|value| value.to_string()), + }, + )) + }) + .collect(); + + if raw_segments.is_empty() { + return vec![]; + } + + if raw_segments.len() == 1 { + return vec![Segment { + start: chunk_start_sec, + duration: chunk_duration_sec, + ..raw_segments.into_iter().next().unwrap().2 + }]; + } + + let timings = normalize_raw_segment_timings(&raw_segments, chunk_duration_sec) + .unwrap_or_else(|| synthetic_segment_timings(&raw_segments, chunk_duration_sec)); + + raw_segments .into_iter() - .map(|segment| Segment { - text: segment.text().trim().to_string(), - start: chunk_start_sec + segment.start(), - duration: segment.end() - segment.start(), - confidence: segment.confidence() as f64, - language: segment.language().map(|value| value.to_string()), + .zip(timings) + .map(|((_, _, segment), (start_offset, duration))| Segment { + start: chunk_start_sec + start_offset, + duration, + ..segment }) - .filter(|segment| !segment.text.is_empty() && segment.duration > 0.0) - .collect()) + .collect() +} + +fn normalize_raw_segment_timings( + raw_segments: &[(f64, f64, Segment)], + chunk_duration_sec: f64, +) -> Option> { + let mut clamped_bounds = Vec::with_capacity(raw_segments.len()); + let mut previous_end = 0.0; + + for (start, end, _) in raw_segments { + if 
!start.is_finite() || !end.is_finite() { + return None; + } + + let start = (*start).max(0.0).max(previous_end); + let end = (*end).max(0.0); + if end <= start { + return None; + } + + clamped_bounds.push((start, end)); + previous_end = end; + } + + if previous_end <= 0.0 { + return None; + } + + let scale = chunk_duration_sec / previous_end; + let mut timings = Vec::with_capacity(clamped_bounds.len()); + + for (idx, (start, end)) in clamped_bounds.into_iter().enumerate() { + let start = (start * scale).min(chunk_duration_sec); + let end = if idx + 1 == raw_segments.len() { + chunk_duration_sec + } else { + (end * scale).min(chunk_duration_sec) + }; + + if end <= start { + return None; + } + + timings.push((start, end - start)); + } + + Some(timings) +} + +fn synthetic_segment_timings( + raw_segments: &[(f64, f64, Segment)], + chunk_duration_sec: f64, +) -> Vec<(f64, f64)> { + let total_weight: usize = raw_segments + .iter() + .map(|(_, _, segment)| segment.text.split_whitespace().count().max(1)) + .sum(); + let mut cursor = 0.0; + + raw_segments + .iter() + .enumerate() + .map(|(idx, (_, _, segment))| { + let weight = segment.text.split_whitespace().count().max(1) as f64; + let start = cursor; + let end = if idx + 1 == raw_segments.len() { + chunk_duration_sec + } else { + cursor + chunk_duration_sec * (weight / total_weight as f64) + }; + cursor = end; + (start, end - start) + }) + .collect() } #[cfg(test)] mod tests { - use hypr_language::ISO639; - use super::*; #[test] fn parse_single_language() { let params = parse_listen_params("language=en").unwrap(); assert_eq!(params.languages.len(), 1); - assert_eq!(params.languages[0].iso639(), ISO639::En); + assert_eq!(params.languages[0].iso639().code(), "en"); } #[test] fn parse_multiple_languages() { let params = parse_listen_params("language=en&language=ko").unwrap(); assert_eq!(params.languages.len(), 2); - assert_eq!(params.languages[0].iso639(), ISO639::En); - assert_eq!(params.languages[1].iso639(), ISO639::Ko); 
+ assert_eq!(params.languages[0].iso639().code(), "en"); + assert_eq!(params.languages[1].iso639().code(), "ko"); } #[test] @@ -146,4 +271,75 @@ mod tests { assert_eq!(params.channels, 1); assert_eq!(params.sample_rate, TARGET_SAMPLE_RATE); } + + #[test] + fn preserves_multiple_segments_with_normalized_timings() { + let segments = build_chunk_segments( + vec![ + hypr_whisper_local::Segment { + text: "hello".to_string(), + language: Some("en".to_string()), + start: 0.0, + end: 1.0, + confidence: 0.8, + ..Default::default() + }, + hypr_whisper_local::Segment { + text: "again".to_string(), + language: Some("en".to_string()), + start: 1.5, + end: 2.0, + confidence: 1.0, + ..Default::default() + }, + ], + 10.0, + 4.0, + ); + + assert_eq!(segments.len(), 2); + assert_eq!(segments[0].start, 10.0); + assert_eq!(segments[0].duration, 2.0); + assert_eq!(segments[0].text, "hello"); + assert_eq!(segments[0].language.as_deref(), Some("en")); + assert!((segments[0].confidence - 0.8).abs() < 1e-6); + + assert_eq!(segments[1].start, 13.0); + assert_eq!(segments[1].duration, 1.0); + assert_eq!(segments[1].text, "again"); + assert_eq!(segments[1].language.as_deref(), Some("en")); + assert!((segments[1].confidence - 1.0).abs() < 1e-6); + } + + #[test] + fn falls_back_to_synthetic_timings_when_raw_timings_are_invalid() { + let segments = build_chunk_segments( + vec![ + hypr_whisper_local::Segment { + text: "hello world".to_string(), + language: Some("en".to_string()), + start: 0.0, + end: 0.0, + confidence: 0.8, + ..Default::default() + }, + hypr_whisper_local::Segment { + text: "again".to_string(), + language: Some("en".to_string()), + start: 0.0, + end: 0.0, + confidence: 1.0, + ..Default::default() + }, + ], + 10.0, + 3.0, + ); + + assert_eq!(segments.len(), 2); + assert_eq!(segments[0].start, 10.0); + assert_eq!(segments[0].duration, 2.0); + assert_eq!(segments[1].start, 12.0); + assert_eq!(segments[1].duration, 1.0); + } } diff --git 
a/crates/transcribe-whisper-local/src/service/response.rs b/crates/transcribe-whisper-local/src/service/response.rs index 42075bbb96..5730a942f4 100644 --- a/crates/transcribe-whisper-local/src/service/response.rs +++ b/crates/transcribe-whisper-local/src/service/response.rs @@ -83,10 +83,16 @@ pub(super) fn build_stream_words(segment: &crate::service::Segment) -> Vec Vec { +pub(super) fn build_batch_words( + segment: &crate::service::Segment, + channel: i32, +) -> Vec { build_stream_words(segment) .into_iter() - .map(batch::Word::from) + .map(|word| batch::Word { + channel, + ..batch::Word::from(word) + }) .collect() } diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index 803fda4662..46dcd46732 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -413,7 +413,6 @@ async fn handle_websocket( } let _ = ws_sender.close().await; - return; } Err(error) => { send_ws_best_effort( @@ -426,7 +425,6 @@ async fn handle_websocket( ) .await; let _ = ws_sender.close().await; - return; } } } @@ -434,6 +432,7 @@ async fn handle_websocket( type TranscriptionStream = Pin> + Send>>; +#[allow(clippy::type_complexity)] fn build_transcription_streams( total_channels: usize, loaded_model: &hypr_whisper_local::LoadedWhisper, diff --git a/crates/transcript/src/words/assembly.rs b/crates/transcript/src/words/assembly.rs index 1051cbad22..65d5830df0 100644 --- a/crates/transcript/src/words/assembly.rs +++ b/crates/transcript/src/words/assembly.rs @@ -220,9 +220,8 @@ mod tests { &mut result, token("look", None, 0, 100), "look".to_string(), - 0, ); - push_assembled_word(&mut result, token(".", None, 100, 110), ".".to_string(), 0); + push_assembled_word(&mut result, token(".", None, 100, 110), ".".to_string()); assert_eq!(result.len(), 1); assert_eq!(result[0].text, "look."); @@ -235,13 +234,11 @@ mod tests { &mut result, 
token("look", None, 0, 100), "look.".to_string(), - 0, ); push_assembled_word( &mut result, token("Everyone", None, 110, 200), "Everyone".to_string(), - 0, ); assert_eq!(result.len(), 2); @@ -292,12 +289,11 @@ mod tests { #[test] fn merges_contraction_into_previous_word() { let mut result = Vec::new(); - push_assembled_word(&mut result, token("it", None, 0, 100), " it".to_string(), 0); + push_assembled_word(&mut result, token("it", None, 0, 100), " it".to_string()); push_assembled_word( &mut result, token("'s", None, 100, 150), "'s".to_string(), - 0, ); assert_eq!(result.len(), 1); diff --git a/plugins/listener2/js/bindings.gen.ts b/plugins/listener2/js/bindings.gen.ts index 79becf45e8..996b6d6bfe 100644 --- a/plugins/listener2/js/bindings.gen.ts +++ b/plugins/listener2/js/bindings.gen.ts @@ -84,14 +84,15 @@ denoiseEvent: "plugin:listener2:denoise-event" export type BatchAlternatives = { transcript: string; confidence: number; words?: BatchWord[] } export type BatchChannel = { alternatives: BatchAlternatives[] } export type BatchErrorCode = "unknown" | "audio_metadata_join_failed" | "audio_metadata_read_failed" | "provider_request_failed" | "actor_spawn_failed" | "stream_start_cancelled" | "stream_stopped_without_completion_signal" | "stream_finished_without_status" | "stream_start_failed" | "stream_error" | "stream_timeout" -export type BatchEvent = { type: "batchStarted"; session_id: string } | { type: "batchCompleted"; session_id: string } | { type: "batchResponse"; session_id: string; response: BatchResponse; mode: BatchRunMode } | { type: "batchProgress"; session_id: string; response: StreamResponse; percentage: number } | { type: "batchFailed"; session_id: string; code: BatchErrorCode; error: string } +export type BatchEvent = { type: "batchStarted"; session_id: string } | { type: "batchCompleted"; session_id: string } | { type: "batchResponse"; session_id: string; response: BatchResponse; mode: BatchRunMode } | { type: "batchProgress"; session_id: string; 
event: BatchStreamEvent } | { type: "batchFailed"; session_id: string; code: BatchErrorCode; error: string } export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[] } -export type BatchProvider = "argmax" | "deepgram" | "soniox" | "assemblyai" | "fireworks" | "openai" | "gladia" | "elevenlabs" | "dashscope" | "mistral" | "hyprnote" | "am" | "cactus" +export type BatchProvider = "argmax" | "whispercpp" | "deepgram" | "soniox" | "assemblyai" | "fireworks" | "openai" | "gladia" | "elevenlabs" | "dashscope" | "mistral" | "hyprnote" | "am" | "cactus" export type BatchResponse = { metadata: JsonValue; results: BatchResults } export type BatchResults = { channels: BatchChannel[] } export type BatchRunMode = "direct" | "streamed" export type BatchRunOutput = { session_id: string; mode: BatchRunMode; response: BatchResponse } -export type BatchWord = { word: string; start: number; end: number; confidence: number; speaker: number | null; punctuated_word: string | null } +export type BatchStreamEvent = { type: "progress"; percentage: number; partial_text?: string | null } | { type: "segment"; response: StreamResponse; percentage: number } | { type: "terminal"; request_id: string; created: string; duration: number; channels: number } | { type: "result"; response: BatchResponse } | { type: "error"; error_code: number | null; error_message: string; provider: string } +export type BatchWord = { word: string; start: number; end: number; confidence: number; channel?: number; speaker: number | null; punctuated_word: string | null } export type DenoiseEvent = { type: "denoiseStarted"; session_id: string } | { type: "denoiseProgress"; session_id: string; percentage: number } | { type: "denoiseCompleted"; session_id: string } | { type: "denoiseFailed"; session_id: string; error: string } export type DenoiseParams = { session_id: string; input_path: 
string; output_path: string } export type JsonValue = null | boolean | number | string | JsonValue[] | Partial<{ [key in string]: JsonValue }>