diff --git a/.env.example b/.env.example index 5c736b1..bc0d668 100644 --- a/.env.example +++ b/.env.example @@ -11,5 +11,8 @@ ARISTECH_SECRET= AZURE_SUBSCRIPTION_KEY=your_azure_key AZURE_REGION=your_azure_region +# ElevenLabs Configuration +ELEVENLABS_API_KEY=your_elevenlabs_api_key + # Audio Knife Configuration AUDIO_KNIFE_ADDRESS=127.0.0.1:8123 diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 9cc39fd..0f87866 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -11,3 +11,4 @@ ## Code Minimalism - Avoid defensive code unless there is concrete evidence it is necessary. - Avoid redundant logic and repeated calls; keep only the minimal behavior required for correctness. +- Do not add tests unless explicitly requested by the user. diff --git a/Cargo.toml b/Cargo.toml index 23e14d5..56bbbdf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "filter-test", "services/aristech", "services/azure", + "services/elevenlabs", "services/google-transcribe", "services/openai-dialog", "services/playback", @@ -27,6 +28,7 @@ openai-dialog = { path = "services/openai-dialog" } azure = { workspace = true } azure-speech = { workspace = true } aristech = { workspace = true } +elevenlabs = { workspace = true } # basic @@ -84,6 +86,7 @@ context-switch-core = { path = "core" } azure = { path = "services/azure" } playback = { path = "services/playback" } aristech = { path = "services/aristech" } +elevenlabs = { path = "services/elevenlabs" } anyhow = "1.0.102" derive_more = { version = "2.1.1", features = ["full"] } diff --git a/README.md b/README.md index ca142fd..428c236 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Context Switch is a Rust-based framework for building real-time conversational a - Pluggable service architecture - Integration with: - Azure Speech Services (transcription, translation, synthesis) + - ElevenLabs realtime speech-to-text (Scribe v2 Realtime) - OpenAI dialog services - Asynchronous processing using Tokio @@ -16,6 +17,7 @@ Context Switch is a Rust-based framework for building real-time conversational a - `core/`: Core functionality and interfaces - `services/`: Implementation of various service integrations - `azure/`: Azure Speech Services integration + - `elevenlabs/`: ElevenLabs speech-to-text integration - `google-transcribe/`: Google Speech-to-Text integration (WIP) - `openai-dialog/`: OpenAI conversational services integration - `audio-knife/`: WebSocket server that implements the [mod_audio_fork](https://github.com/questnet/freeswitch-modules/tree/questnet/mod_audio_fork) protocol for real-time audio streaming from telephony systems via [FreeSWITCH](https://signalwire.com/freeswitch). Provides a bridge between audio sources and the Context Switch framework. @@ -61,6 +63,9 @@ cargo run --example openai-dialog # Run Azure transcribe example cargo run --example azure-transcribe +# Run ElevenLabs transcribe example +cargo run --example elevenlabs-transcribe + # Run Azure synthesize example cargo run --example azure-synthesize ``` @@ -90,6 +95,9 @@ OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview AZURE_SUBSCRIPTION_KEY=your_azure_key AZURE_REGION=your_azure_region +# ElevenLabs Configuration +ELEVENLABS_API_KEY=your_elevenlabs_key + # Audio Knife Configuration AUDIO_KNIFE_ADDRESS=127.0.0.1:8123 ``` diff --git a/core/Cargo.toml b/core/Cargo.toml index a115fb2..8924c14 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -11,4 +11,6 @@ derive_more = { workspace = true } serde = { workspace = true } # For function calling parameters. -serde_json = { workspace = true } \ No newline at end of file +serde_json = { workspace = true } +isolang = "2.4.0" +oxilangtag = "0.1.5" \ No newline at end of file diff --git a/core/src/language.rs b/core/src/language.rs new file mode 100644 index 0000000..e9216e3 --- /dev/null +++ b/core/src/language.rs @@ -0,0 +1,112 @@ +use std::fmt; + +use isolang::Language; +use oxilangtag::LanguageTag; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum LanguageCodeError { + InvalidBcp47Tag { tag: String, message: String }, + UnsupportedLanguage { language: String }, +} + +impl fmt::Display for LanguageCodeError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + LanguageCodeError::InvalidBcp47Tag { tag, message } => { + write!(f, "Invalid BCP 47 tag '{tag}': {message}") + } + LanguageCodeError::UnsupportedLanguage { language } => { + write!(f, "Unsupported language subtag '{language}'") + } + } + } +} + +impl std::error::Error for LanguageCodeError {} + +/// Converts a BCP 47 language tag into its ISO 639-3 language code. +/// +/// The conversion uses the primary language subtag only and ignores script, region, variant, +/// and extension subtags. +pub fn bcp47_to_iso639_3(tag: &str) -> Result<&'static str, LanguageCodeError> { + let parsed = LanguageTag::parse(tag).map_err(|error| LanguageCodeError::InvalidBcp47Tag { + tag: tag.to_string(), + message: error.to_string(), + })?; + + let primary_language = parsed.primary_language(); + let language = match primary_language.len() { + 2 => Language::from_639_1(primary_language), + 3 => Language::from_639_3(primary_language), + _ => None, + }; + + language + .map(|x| x.to_639_3()) + .ok_or_else(|| LanguageCodeError::UnsupportedLanguage { + language: primary_language.to_string(), + }) +} + +/// Converts an ISO 639 language code into a BCP 47 language tag. +/// +/// The conversion returns a primary language tag only. If a matching ISO 639-1 code exists, +/// that 2-letter code is preferred (for example `eng` -> `en`). Otherwise the original ISO +/// 639-3 code is used as the BCP 47 primary language subtag. +/// +/// Supports ISO 639-1 (2-letter) and ISO 639-3 (3-letter) input codes. +pub fn iso639_to_bcp47(code: &str) -> Result { + let language = match code.len() { + 2 => Language::from_639_1(code), + 3 => Language::from_639_3(code), + _ => None, + } + .ok_or_else(|| LanguageCodeError::UnsupportedLanguage { + language: code.to_string(), + })?; + + Ok(language + .to_639_1() + .map(str::to_string) + .unwrap_or_else(|| language.to_639_3().to_string())) +} + +/// Converts an ISO 639-3 language code into a BCP 47 language tag. +pub fn iso639_3_to_bcp47(code: &str) -> Result { + iso639_to_bcp47(code) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bcp47_to_iso639_3_for_primary_language_tags() { + assert_eq!(bcp47_to_iso639_3("en").unwrap(), "eng"); + assert_eq!(bcp47_to_iso639_3("de").unwrap(), "deu"); + assert_eq!(bcp47_to_iso639_3("fr").unwrap(), "fra"); + } + + #[test] + fn bcp47_to_iso639_3_ignores_non_primary_subtags() { + assert_eq!(bcp47_to_iso639_3("en-US").unwrap(), "eng"); + assert_eq!(bcp47_to_iso639_3("zh-Hant-TW").unwrap(), "zho"); + } + + #[test] + fn bcp47_to_iso639_3_rejects_malformed_tags() { + let err = bcp47_to_iso639_3("en--US").unwrap_err(); + assert!(matches!(err, LanguageCodeError::InvalidBcp47Tag { .. })); + } + + #[test] + fn bcp47_to_iso639_3_rejects_unsupported_primary_language() { + let err = bcp47_to_iso639_3("qaa").unwrap_err(); + assert_eq!( + err, + LanguageCodeError::UnsupportedLanguage { + language: "qaa".to_string(), + } + ); + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index f2d4d27..48e3ff8 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -3,6 +3,7 @@ pub mod billing_collector; mod billing_context; pub mod conversation; mod duration; +pub mod language; mod protocol; mod registry; pub mod service; diff --git a/examples/elevenlabs-transcribe.rs b/examples/elevenlabs-transcribe.rs new file mode 100644 index 0000000..f230c96 --- /dev/null +++ b/examples/elevenlabs-transcribe.rs @@ -0,0 +1,161 @@ +use std::{env, path::Path, time::Duration}; + +use anyhow::{Context, Result, bail}; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use rodio::DeviceSinkBuilder; +use tokio::{ + select, + sync::mpsc::{channel, unbounded_channel}, +}; + +use context_switch::{ + AudioConsumer, InputModality, OutputModality, services::ElevenLabsTranscribe, +}; +use context_switch_core::{ + AudioFormat, AudioFrame, audio, + conversation::{Conversation, Input}, + service::Service, +}; + +const LANGUAGE: &str = "de-DE"; + +#[tokio::main] +async fn main() -> Result<()> { + dotenvy::dotenv_override()?; + tracing_subscriber::fmt::init(); + + let mut args = env::args(); + match args.len() { + 1 => recognize_from_microphone().await?, + 2 => recognize_from_wav(Path::new(&args.nth(1).unwrap())).await?, + _ => bail!("Invalid number of arguments, expect zero or one"), + } + + Ok(()) +} + +async fn recognize_from_wav(file: &Path) -> Result<()> { + let format = AudioFormat { + channels: 1, + sample_rate: 16_000, + }; + + let frames = playback::audio_file_to_frames(file, format)?; + if frames.is_empty() { + bail!("No frames in the audio file"); + } + + let (producer, input_consumer) = format.new_channel(); + for frame in frames { + producer.produce(frame)?; + } + + recognize(format, input_consumer).await +} + +async fn recognize_from_microphone() -> Result<()> { + // Keep an output sink alive so Bluetooth headsets (e.g. AirPods) can switch to a + // bidirectional profile. Without this, some devices report an input stream of zeros. + let _output_sink = match DeviceSinkBuilder::open_default_sink() { + Ok(sink) => { + println!("Opened default output sink for headset profile"); + Some(sink) + } + Err(e) => { + println!("Warning: Failed to open default output sink: {e}"); + None + } + }; + + let host = cpal::default_host(); + let device = host + .default_input_device() + .context("Failed to get default input device")?; + let config = device + .default_input_config() + .expect("Failed to get default input config"); + + println!("config: {config:?}"); + + let channels = config.channels(); + let sample_rate = config.sample_rate(); + let format = AudioFormat::new(channels, sample_rate); + + let (producer, input_consumer) = format.new_channel(); + + let stream = device + .build_input_stream( + &config.into(), + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let samples = audio::into_i16(data); + + let frame = AudioFrame { format, samples }; + if producer.produce(frame).is_err() { + println!("Failed to send audio data"); + } + }, + move |err| { + eprintln!("Error occurred on stream: {err}"); + }, + Some(Duration::from_secs(1)), + ) + .expect("Failed to build input stream"); + + stream.play().expect("Failed to play stream"); + + recognize(format, input_consumer).await +} + +async fn recognize(format: AudioFormat, mut input_consumer: AudioConsumer) -> Result<()> { + let params = elevenlabs::transcribe::Params { + api_key: env::var("ELEVENLABS_API_KEY").context("ELEVENLABS_API_KEY undefined")?, + model: None, + host: None, + language: Some(LANGUAGE.to_owned()), + include_language_detection: Some(false), + vad_silence_threshold_secs: None, + vad_threshold: None, + min_speech_duration_ms: None, + min_silence_duration_ms: None, + previous_text: None, + }; + + let (output_producer, mut output_consumer) = unbounded_channel(); + let (conv_input_producer, conv_input_consumer) = channel(16_384); + + let service = ElevenLabsTranscribe; + let mut conversation = service.conversation( + params, + Conversation::new( + InputModality::Audio { format }, + [OutputModality::Text, OutputModality::InterimText], + conv_input_consumer, + output_producer, + ), + ); + + loop { + select! { + result = &mut conversation => { + result.context("Conversation stopped")?; + break; + } + input = input_consumer.consume() => { + if let Some(frame) = input { + conv_input_producer.try_send(Input::Audio { frame })?; + } else { + break; + } + } + output = output_consumer.recv() => { + if let Some(output) = output { + println!("{output:?}"); + } else { + break; + } + } + } + } + + Ok(()) +} diff --git a/services/elevenlabs/Cargo.toml b/services/elevenlabs/Cargo.toml new file mode 100644 index 0000000..714bb0e --- /dev/null +++ b/services/elevenlabs/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "elevenlabs" +version = "0.1.0" +edition = "2024" + +[dependencies] +context-switch-core = { workspace = true } + +anyhow = { workspace = true } +async-trait = { workspace = true } +base64 = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "time"] } +tokio-tungstenite = { version = "0.28.0", features = ["connect", "native-tls"] } +tracing = { workspace = true } +url = { workspace = true } diff --git a/services/elevenlabs/src/lib.rs b/services/elevenlabs/src/lib.rs new file mode 100644 index 0000000..c29b4ab --- /dev/null +++ b/services/elevenlabs/src/lib.rs @@ -0,0 +1,3 @@ +pub mod transcribe; + +pub use transcribe::ElevenLabsTranscribe; diff --git a/services/elevenlabs/src/transcribe.rs b/services/elevenlabs/src/transcribe.rs new file mode 100644 index 0000000..152d61d --- /dev/null +++ b/services/elevenlabs/src/transcribe.rs @@ -0,0 +1,546 @@ +use anyhow::{Context, Result, anyhow, bail}; +use async_trait::async_trait; +use base64::Engine; +use futures::{SinkExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::select; +use tokio::sync::mpsc; +use tokio::time::{Duration, sleep}; +use tokio_tungstenite::{ + connect_async_with_config, + tungstenite::{ + Message, + client::IntoClientRequest, + http::{HeaderName, HeaderValue}, + }, +}; +use tracing::{debug, error, warn}; +use url::Url; + +use context_switch_core::{ + AudioFormat, Service, + conversation::{Conversation, ConversationInput, ConversationOutput, Input}, + language::{bcp47_to_iso639_3, iso639_to_bcp47}, +}; + +// Behavior notes of Scribe v2 as of 20260402: +// +// - When `include_language_detection` is enabled, both the committed_transcript and the +// committed_transcribe_with_timestamp are sent in succession (with the same text it seems). +// - When no audio packets are sent for 15 seconds, the socket just closes without any error / +// notification. +// - When a language hint is set, it sometimes translate to the target language. If it does it, +// seems to depend on what language was spoken before. +// - Sometimes when you speak some bogus text, like "Däm, Däm, Däm", the partial_transcript shows it, +// but the committed_transcript is empty. (We could return the partial transcript in this case). + +const DEFAULT_REALTIME_HOST: &str = "wss://api.elevenlabs.io/v1/speech-to-text/realtime"; +const API_KEY_HEADER: &str = "xi-api-key"; +const WRITER_SHUTDOWN_GRACE_PERIOD: Duration = Duration::from_secs(2); +const DEFAULT_MODEL: &str = "scribe_v2_realtime"; +const DEFAULT_INCLUDE_LANGUAGE_DETECTION: bool = false; + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Params { + /// ElevenLabs API key for the `xi-api-key` websocket header. + pub api_key: String, + /// Optional realtime model. Defaults to `scribe_v2_realtime` when omitted. + pub model: Option, + /// Optional websocket endpoint override. + pub host: Option, + /// Optional language hint in BCP 47 format (for example `en-US`). + pub language: Option, + /// Include detected language in timestamped output. + /// When omitted, this integration defaults it to `false`. + pub include_language_detection: Option, + /// VAD silence threshold in seconds. Range: `0.3..=3.0`. Default: `1.5`. + pub vad_silence_threshold_secs: Option, + /// VAD activity threshold. Range: `0.1..=0.9`. Default: `0.4`. + pub vad_threshold: Option, + /// Minimum speech duration in ms. Range: `50..=2000`. Default: `100`. + pub min_speech_duration_ms: Option, + /// Minimum silence duration in ms. Range: `50..=2000`. Default: `100`. + pub min_silence_duration_ms: Option, + /// Optional prior text context sent only with the first `input_audio_chunk`. + pub previous_text: Option, +} + +#[derive(Debug, Clone, Copy, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AudioEncoding { + #[serde(rename = "pcm_8000")] + Pcm8000, + #[serde(rename = "pcm_16000")] + Pcm16000, + #[serde(rename = "pcm_22050")] + Pcm22050, + #[serde(rename = "pcm_24000")] + Pcm24000, + #[serde(rename = "pcm_44100")] + Pcm44100, + #[serde(rename = "pcm_48000")] + Pcm48000, +} + +impl AudioEncoding { + fn as_str(self) -> &'static str { + match self { + AudioEncoding::Pcm8000 => "pcm_8000", + AudioEncoding::Pcm16000 => "pcm_16000", + AudioEncoding::Pcm22050 => "pcm_22050", + AudioEncoding::Pcm24000 => "pcm_24000", + AudioEncoding::Pcm44100 => "pcm_44100", + AudioEncoding::Pcm48000 => "pcm_48000", + } + } +} + +#[derive(Debug)] +pub struct ElevenLabsTranscribe; + +#[derive(Debug, Clone, Copy)] +struct ConversationLoopConfig { + input_format: AudioFormat, + include_language_detection: bool, +} + +#[async_trait] +impl Service for ElevenLabsTranscribe { + type Params = Params; + + async fn conversation(&self, params: Params, conversation: Conversation) -> Result<()> { + let input_format = conversation.require_audio_input()?; + conversation.require_text_output(true)?; + + if input_format.channels != 1 { + bail!("ElevenLabs realtime currently requires mono input audio"); + } + + let include_language_detection = params + .include_language_detection + .unwrap_or(DEFAULT_INCLUDE_LANGUAGE_DETECTION); + + let encoding = resolve_audio_encoding(input_format)?; + let endpoint = build_endpoint(¶ms, encoding, include_language_detection)?; + + let mut request = endpoint + .as_str() + .into_client_request() + .context("Building websocket request")?; + request.headers_mut().insert( + HeaderName::from_static(API_KEY_HEADER), + HeaderValue::from_str(¶ms.api_key).context("Invalid xi-api-key header value")?, + ); + + // Disable Nagle (TCP_NODELAY) to reduce latency for realtime audio chunk streaming. + let (socket, _) = connect_async_with_config(request, None, true) + .await + .context("Connecting to ElevenLabs realtime websocket")?; + + let (write, mut read) = socket.split(); + let (mut input, output) = conversation.start()?; + let (outbound_tx, outbound_rx) = mpsc::unbounded_channel(); + let writer_task = tokio::spawn(run_writer(write, outbound_rx)); + let mut outbound_closed = false; + + let conversation_result = run_conversation_loop( + &mut input, + &output, + &mut read, + &outbound_tx, + &mut outbound_closed, + ConversationLoopConfig { + input_format, + include_language_detection, + }, + params.previous_text.as_deref(), + ) + .await; + + if !outbound_closed { + let _ = outbound_tx.send(OutboundMessage::Close); + } + + drop(outbound_tx); + + let shutdown_result = shutdown_writer_task(writer_task).await; + + conversation_result?; + shutdown_result + } +} + +async fn run_conversation_loop( + input: &mut ConversationInput, + output: &ConversationOutput, + read: &mut R, + outbound_tx: &mpsc::UnboundedSender, + outbound_closed: &mut bool, + config: ConversationLoopConfig, + mut previous_text_for_next_chunk: Option<&str>, +) -> Result<()> +where + R: futures::Stream> + Unpin, +{ + let mut input_closed = false; + + loop { + select! { + input_event = input.recv(), if !input_closed => { + match input_event { + Some(Input::Audio { frame }) => { + if frame.format != config.input_format { + bail!("Received mixed input audio formats in conversation"); + } + + let previous_text = previous_text_for_next_chunk.take(); + let msg = build_audio_chunk_message(frame, false, previous_text)?; + outbound_tx + .send(msg) + .map_err(|_| anyhow!("ElevenLabs websocket writer task stopped unexpectedly"))?; + } + Some(_) => {} + None => { + input_closed = true; + if !*outbound_closed { + let _ = outbound_tx.send(OutboundMessage::Close); + *outbound_closed = true; + } + } + } + } + msg = read.next() => { + match msg { + Some(Ok(message)) => { + process_server_message(message, output, config.include_language_detection)?; + } + Some(Err(e)) => { + bail!("Error reading ElevenLabs websocket: {e}"); + } + None => return Ok(()), + } + } + } + } +} + +async fn shutdown_writer_task(mut writer_task: tokio::task::JoinHandle>) -> Result<()> { + select! { + join_result = &mut writer_task => { + match join_result { + Ok(result) => result, + Err(e) => bail!("ElevenLabs websocket writer task failed to join: {e}"), + } + } + _ = sleep(WRITER_SHUTDOWN_GRACE_PERIOD) => { + warn!( + "ElevenLabs writer shutdown grace period reached; aborting writer task after {:?}", + WRITER_SHUTDOWN_GRACE_PERIOD + ); + writer_task.abort(); + let _ = writer_task.await; + Ok(()) + } + } +} + +fn resolve_audio_encoding(input_format: AudioFormat) -> Result { + let encoding = match input_format.sample_rate { + 8_000 => AudioEncoding::Pcm8000, + 16_000 => AudioEncoding::Pcm16000, + 22_050 => AudioEncoding::Pcm22050, + 24_000 => AudioEncoding::Pcm24000, + 44_100 => AudioEncoding::Pcm44100, + 48_000 => AudioEncoding::Pcm48000, + _ => { + bail!( + "Unsupported input sample rate {} for ElevenLabs realtime. Supported sample rates: 8000, 16000, 22050, 24000, 44100, 48000 Hz", + input_format.sample_rate + ) + } + }; + + Ok(encoding) +} + +fn build_endpoint( + params: &Params, + audio_encoding: AudioEncoding, + include_language_detection: bool, +) -> Result { + let host = params.host.as_deref().unwrap_or(DEFAULT_REALTIME_HOST); + let mut url = Url::parse(host).context("Invalid ElevenLabs realtime host URL")?; + + { + let mut q = url.query_pairs_mut(); + q.append_pair("model_id", params.model.as_deref().unwrap_or(DEFAULT_MODEL)); + // Defaulting to false enables automatic translation to the requested language. + q.append_pair( + "include_language_detection", + if include_language_detection { + "true" + } else { + "false" + }, + ); + q.append_pair("audio_format", audio_encoding.as_str()); + q.append_pair("commit_strategy", "vad"); + + if let Some(language) = params.language.as_deref() { + let language_code = bcp47_to_iso639_3(language).map_err(|error| { + anyhow!("Invalid ElevenLabs params.language '{language}': {error}") + })?; + q.append_pair("language_code", language_code); + } + if let Some(vad_silence_threshold_secs) = params.vad_silence_threshold_secs { + q.append_pair( + "vad_silence_threshold_secs", + &vad_silence_threshold_secs.to_string(), + ); + } + if let Some(vad_threshold) = params.vad_threshold { + q.append_pair("vad_threshold", &vad_threshold.to_string()); + } + if let Some(min_speech_duration_ms) = params.min_speech_duration_ms { + q.append_pair( + "min_speech_duration_ms", + &min_speech_duration_ms.to_string(), + ); + } + if let Some(min_silence_duration_ms) = params.min_silence_duration_ms { + q.append_pair( + "min_silence_duration_ms", + &min_silence_duration_ms.to_string(), + ); + } + // Intentionally omit `enable_logging`: the provider default is `true`. + // `enable_logging=false` (zero retention mode) is enterprise-only. + } + + Ok(url) +} + +fn build_audio_chunk_message( + frame: context_switch_core::AudioFrame, + commit: bool, + previous_text: Option<&str>, +) -> Result { + let request = InputAudioChunk { + message_type: "input_audio_chunk", + audio_base_64: base64::engine::general_purpose::STANDARD.encode(frame.to_le_bytes()), + commit, + sample_rate: frame.format.sample_rate, + previous_text, + }; + + let json = serde_json::to_string(&request).context("Serializing input audio chunk")?; + Ok(OutboundMessage::Ws(Message::Text(json.into()))) +} + +enum OutboundMessage { + Ws(Message), + Close, +} + +async fn run_writer( + mut write: S, + mut outbound_rx: mpsc::UnboundedReceiver, +) -> Result<()> +where + S: futures::Sink + Unpin, +{ + while let Some(outbound) = outbound_rx.recv().await { + match outbound { + OutboundMessage::Ws(message) => { + write + .send(message) + .await + .context("Sending input audio chunk")?; + } + OutboundMessage::Close => { + write + .close() + .await + .context("Closing websocket write stream")?; + return Ok(()); + } + } + } + + Ok(()) +} + +#[derive(Debug, Serialize)] +struct InputAudioChunk<'a> { + message_type: &'static str, + audio_base_64: String, + commit: bool, + sample_rate: u32, + #[serde(skip_serializing_if = "Option::is_none")] + previous_text: Option<&'a str>, +} + +fn process_server_message( + message: Message, + output: &ConversationOutput, + include_language_detection: bool, +) -> Result<()> { + match message { + Message::Text(text) => { + debug!("ElevenLabs websocket received: {}", text); + process_server_json(text.as_str(), output, include_language_detection) + } + Message::Binary(_) => Ok(()), + Message::Ping(payload) => { + error!( + "Received ElevenLabs websocket ping ({} bytes payload)", + payload.len() + ); + Ok(()) + } + Message::Pong(_) => Ok(()), + Message::Close(_) => Ok(()), + Message::Frame(_) => Ok(()), + } +} + +fn process_server_json( + json: &str, + output: &ConversationOutput, + include_language_detection: bool, +) -> Result<()> { + let envelope: RealtimeEnvelope = serde_json::from_str(json) + .with_context(|| format!("Parsing ElevenLabs server event: {json}"))?; + + match envelope.message_type.as_str() { + "session_started" => { + debug!("ElevenLabs session started"); + Ok(()) + } + "partial_transcript" => { + let event: PartialTranscript = serde_json::from_value(envelope.payload)?; + output.text(false, event.text, None) + } + "committed_transcript" => { + if include_language_detection { + // Ignoring committed_transcript because include_language_detection=true; expecting committed_transcript_with_timestamps + return Ok(()); + } + let event: CommittedTranscript = serde_json::from_value(envelope.payload)?; + output.text(true, event.text, None) + } + "committed_transcript_with_timestamps" => { + let event: CommittedTranscriptWithTimestamps = + serde_json::from_value(envelope.payload.clone())?; + let CommittedTranscriptWithTimestamps { + text, + language_code: detected_language_code, + words: _, + } = event; + + let language_code = detected_language_code + .as_deref() + .and_then(|detected_language| { + match iso639_to_bcp47(detected_language) { + Ok(code) => Some(code.to_string()), + Err(err) => { + error!( + "Failed to convert detected language code '{}' from ISO 639 to BCP47: {}", + detected_language, + err + ); + None + } + } + }); + + output.text(true, text, language_code) + } + // Not in the official documentation, but this happens when the language code is invalid. + "invalid_request" => { + let event: InvalidRequest = serde_json::from_value(envelope.payload)?; + let message = event + .message + .or(event.error) + .unwrap_or_else(|| "ElevenLabs realtime rejected the request".to_owned()); + bail!("ElevenLabs invalid_request: {message}") + } + message_type if is_scribe_error_type(message_type) => { + let message = extract_error_message(&envelope.payload) + .unwrap_or_else(|| "ElevenLabs realtime returned an unspecified error".to_owned()); + bail!("ElevenLabs {message_type}: {message}") + } + _ => { + debug!( + "Ignoring ElevenLabs realtime event: {}", + envelope.message_type + ); + Ok(()) + } + } +} + +#[derive(Debug, Deserialize)] +struct RealtimeEnvelope { + message_type: String, + #[serde(flatten)] + payload: Value, +} + +#[derive(Debug, Deserialize)] +struct PartialTranscript { + text: String, +} + +#[derive(Debug, Deserialize)] +struct CommittedTranscript { + text: String, +} + +#[derive(Debug, Deserialize)] +struct CommittedTranscriptWithTimestamps { + text: String, + language_code: Option, + #[allow(dead_code)] + words: Option>, +} + +#[derive(Debug, Deserialize)] +struct InvalidRequest { + message: Option, + error: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize, Serialize)] +struct WordTimestamp { + text: String, + start: f64, + end: f64, + #[serde(rename = "type")] + kind: String, + #[serde(skip_serializing_if = "Option::is_none")] + logprob: Option, + #[serde(skip_serializing_if = "Option::is_none")] + characters: Option>, +} + +fn is_scribe_error_type(message_type: &str) -> bool { + message_type == "scribe_error" + || (message_type.starts_with("scribe_") && message_type.ends_with("_error")) +} + +fn extract_error_message(payload: &Value) -> Option { + payload + .get("message") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + .or_else(|| { + payload + .get("error") + .and_then(Value::as_str) + .map(ToOwned::to_owned) + }) +} diff --git a/src/context_switch.rs b/src/context_switch.rs index a84d202..bbb2817 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -47,6 +47,7 @@ pub fn registry() -> Registry { .add_service("azure-transcribe", azure::AzureTranscribe) .add_service("azure-synthesize", azure::AzureSynthesize) .add_service("azure-translate", azure::AzureTranslate) + .add_service("elevenlabs-transcribe", elevenlabs::ElevenLabsTranscribe) .add_service("openai-dialog", openai_dialog::OpenAIDialog) .add_service("aristech-transcribe", aristech::AristechTranscribe) .add_service("aristech-synthesize", aristech::AristechSynthesize) diff --git a/src/lib.rs b/src/lib.rs index 84adcdc..6689b98 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,4 +14,5 @@ pub use speech_gate::make_speech_gate_processor; pub mod services { pub use aristech::AristechTranscribe; pub use azure::AzureTranscribe; + pub use elevenlabs::ElevenLabsTranscribe; }