diff --git a/Cargo.lock b/Cargo.lock index 1167f19f6e..0ea829b9ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2793,6 +2793,7 @@ dependencies = [ "insta", "language", "llm-types", + "model-manager", "serde", "serde_json", "thiserror 2.0.18", @@ -10451,6 +10452,7 @@ dependencies = [ "cactus", "futures-util", "llm-types", + "model-manager", "serde", "serde_json", "thiserror 2.0.18", @@ -11102,6 +11104,15 @@ dependencies = [ "zip 2.4.2", ] +[[package]] +name = "model-manager" +version = "0.1.0" +dependencies = [ + "thiserror 2.0.18", + "tokio", + "uuid", +] + [[package]] name = "moxcms" version = "0.7.11" @@ -12518,6 +12529,7 @@ version = "0.0.1" dependencies = [ "am", "audio-utils", + "backon", "base64 0.22.1", "bytes", "cactus-model", @@ -18416,6 +18428,7 @@ dependencies = [ "axum 0.8.8", "axum-extra", "backon", + "cactus", "cactus-model", "data", "dirs 6.0.0", @@ -18427,6 +18440,7 @@ dependencies = [ "language", "local-model", "model-downloader", + "model-manager", "owhisper-client", "owhisper-interface", "port-killer", @@ -20281,6 +20295,7 @@ dependencies = [ "futures-util", "insta", "language", + "model-manager", "owhisper-client", "owhisper-interface", "pico-args", @@ -23233,6 +23248,7 @@ name = "ws-client" version = "0.1.0" dependencies = [ "async-stream", + "backon", "bytes", "futures-util", "serde", diff --git a/Cargo.toml b/Cargo.toml index c0f5c925b7..5e2b14e755 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,6 +119,7 @@ hypr-mac = { path = "crates/mac", package = "mac" } hypr-mcp = { path = "crates/mcp", package = "mcp" } hypr-mobile-bridge = { path = "crates/mobile-bridge", package = "mobile-bridge" } hypr-model-downloader = { path = "crates/model-downloader", package = "model-downloader" } +hypr-model-manager = { path = "crates/model-manager", package = "model-manager" } hypr-mp3 = { path = "crates/mp3", package = "mp3" } hypr-nango = { path = "crates/nango", package = "nango" } hypr-notification = { path = "crates/notification", package = "notification" } diff --git a/apps/desktop/src/stt/contexts.test.tsx b/apps/desktop/src/stt/contexts.test.tsx index a78f9d5466..bb293f84ee 100644 --- a/apps/desktop/src/stt/contexts.test.tsx +++ b/apps/desktop/src/stt/contexts.test.tsx @@ -41,7 +41,7 @@ describe("ListenerProvider detect events", () => { listenMock.mockResolvedValue(() => {}); }); - test("does not stop listening when MicStopped arrives", async () => { + test("stops listening when MicStopped arrives", async () => { const store = createListenerStore(); const stopSpy = vi.fn(); @@ -65,7 +65,7 @@ describe("ListenerProvider detect events", () => { }, }); - expect(stopSpy).not.toHaveBeenCalled(); + expect(stopSpy).toHaveBeenCalledTimes(1); }); test("stops listening when sleep starts", async () => { diff --git a/crates/cactus/Cargo.toml b/crates/cactus/Cargo.toml index 762080ffee..fd636f7388 100644 --- a/crates/cactus/Cargo.toml +++ b/crates/cactus/Cargo.toml @@ -4,11 +4,16 @@ version = "0.1.0" edition = "2024" license = "MIT" +[features] +default = [] +model-manager = ["dep:hypr-model-manager"] + [dependencies] cactus-sys = { git = "https://github.com/cactus-compute/cactus", package = "cactus-sys", rev = "a5acad3" } hypr-language = { workspace = true } hypr-llm-types = { workspace = true } +hypr-model-manager = { workspace = true, optional = true } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt", "sync"] } diff --git a/crates/cactus/src/lib.rs b/crates/cactus/src/lib.rs index 20061391e9..c0f1fe0e08 100644 --- a/crates/cactus/src/lib.rs +++ b/crates/cactus/src/lib.rs @@ -17,3 +17,12 @@ pub use stt::{ pub use vad::{VadOptions, VadResult, VadSegment}; pub use hypr_llm_types::{Response, StreamingParser}; + +#[cfg(feature = "model-manager")] +impl hypr_model_manager::ModelLoader for Model { + type Error = Error; + + fn load(path: &std::path::Path) -> Result { + Model::new(path) + } +} diff --git a/crates/listener-core/src/actors/listener/adapters.rs b/crates/listener-core/src/actors/listener/adapters.rs index ea7def476e..e514826ff2 100644 --- a/crates/listener-core/src/actors/listener/adapters.rs +++ b/crates/listener-core/src/actors/listener/adapters.rs @@ -30,79 +30,80 @@ pub(super) async fn spawn_rx_task( let adapter_kind = AdapterKind::from_url_and_languages(&args.base_url, &args.languages, Some(&args.model)); let is_dual = matches!(args.mode, crate::actors::ChannelMode::MicAndSpeaker); + let policy = connect_policy_for(adapter_kind); let result = match (adapter_kind, is_dual) { (AdapterKind::Argmax, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Argmax, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Soniox, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Soniox, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Fireworks, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Fireworks, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Deepgram, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Deepgram, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::AssemblyAI, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::AssemblyAI, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::OpenAI, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::OpenAI, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Gladia, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Gladia, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::ElevenLabs, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::ElevenLabs, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::DashScope, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::DashScope, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Mistral, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Mistral, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Hyprnote, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Hyprnote, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } (AdapterKind::Cactus, false) => { - spawn_rx_task_single_with_adapter::(args, myself).await + spawn_rx_task_single_with_adapter::(args, myself, policy).await } (AdapterKind::Cactus, true) => { - spawn_rx_task_dual_with_adapter::(args, myself).await + spawn_rx_task_dual_with_adapter::(args, myself, policy).await } }?; @@ -148,9 +149,29 @@ fn desktop_connect_policy() -> hypr_ws_client::client::WebSocketConnectPolicy { } } +fn local_model_connect_policy() -> hypr_ws_client::client::WebSocketConnectPolicy { + hypr_ws_client::client::WebSocketConnectPolicy { + connect_timeout: Duration::from_secs(10), + max_attempts: 15, + retry_delay: Duration::from_secs(5), + } +} + +fn connect_policy_for( + kind: owhisper_client::AdapterKind, +) -> hypr_ws_client::client::WebSocketConnectPolicy { + match kind { + owhisper_client::AdapterKind::Cactus | owhisper_client::AdapterKind::Argmax => { + local_model_connect_policy() + } + _ => desktop_connect_policy(), + } +} + async fn spawn_rx_task_single_with_adapter( args: ListenerArgs, myself: ActorRef, + policy: hypr_ws_client::client::WebSocketConnectPolicy, ) -> Result< ( ChannelSender, @@ -169,7 +190,7 @@ async fn spawn_rx_task_single_with_adapter( .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) - .connect_policy(desktop_connect_policy()) + .connect_policy(policy) .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()) .build_single() .await; @@ -211,6 +232,7 @@ async fn spawn_rx_task_single_with_adapter( async fn spawn_rx_task_dual_with_adapter( args: ListenerArgs, myself: ActorRef, + policy: hypr_ws_client::client::WebSocketConnectPolicy, ) -> Result< ( ChannelSender, @@ -229,7 +251,7 @@ async fn spawn_rx_task_dual_with_adapter( .api_base(args.base_url.clone()) .api_key(args.api_key.clone()) .params(build_listen_params(&args)) - .connect_policy(desktop_connect_policy()) + .connect_policy(policy) .extra_header(DEVICE_FINGERPRINT_HEADER, hypr_host::fingerprint()) .build_dual() .await; diff --git a/crates/listener-core/src/events.rs b/crates/listener-core/src/events.rs index 0b8286945d..a98dc66bd5 100644 --- a/crates/listener-core/src/events.rs +++ b/crates/listener-core/src/events.rs @@ -39,6 +39,8 @@ pub enum SessionProgressEvent { Connecting { session_id: String }, #[serde(rename = "connected")] Connected { session_id: String, adapter: String }, + #[serde(rename = "model_loading")] + ModelLoading { session_id: String }, } #[derive(serde::Serialize, serde::Deserialize, Clone)] diff --git a/crates/listener2-core/Cargo.toml b/crates/listener2-core/Cargo.toml index ec98d88a93..83f7e93b7a 100644 --- a/crates/listener2-core/Cargo.toml +++ b/crates/listener2-core/Cargo.toml @@ -17,7 +17,7 @@ hypr-language = { workspace = true } bytes = { workspace = true } hound = { workspace = true } -owhisper-client = { workspace = true, features = ["argmax"] } +owhisper-client = { workspace = true, features = ["local"] } owhisper-interface = { workspace = true } serde = { workspace = true } diff --git a/crates/listener2-core/src/events.rs b/crates/listener2-core/src/events.rs index 0f5254dd13..ab2db2bb20 100644 --- a/crates/listener2-core/src/events.rs +++ b/crates/listener2-core/src/events.rs @@ -47,6 +47,8 @@ pub enum BatchEvent { code: BatchErrorCode, error: String, }, + #[serde(rename = "modelLoading")] + ModelLoading { session_id: String }, } #[derive(serde::Serialize, Clone)] diff --git a/crates/llm-cactus/Cargo.toml b/crates/llm-cactus/Cargo.toml index d1bffaaf2b..9eca9c690f 100644 --- a/crates/llm-cactus/Cargo.toml +++ b/crates/llm-cactus/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" edition = "2024" [dependencies] -hypr-cactus = { workspace = true } +hypr-cactus = { workspace = true, features = ["model-manager"] } hypr-llm-types = { workspace = true } +hypr-model-manager = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } diff --git a/crates/llm-cactus/src/error.rs b/crates/llm-cactus/src/error.rs index 6467a1ed87..3f1d51674f 100644 --- a/crates/llm-cactus/src/error.rs +++ b/crates/llm-cactus/src/error.rs @@ -2,12 +2,6 @@ pub enum Error { #[error(transparent)] Cactus(#[from] hypr_cactus::Error), - #[error("model not registered: {0}")] - ModelNotRegistered(String), - #[error("model file not found: {0}")] - ModelFileNotFound(String), - #[error("no default model configured")] - NoDefaultModel, - #[error("worker task panicked")] - WorkerPanicked, + #[error(transparent)] + ModelManager(#[from] hypr_model_manager::Error), } diff --git a/crates/llm-cactus/src/manager.rs b/crates/llm-cactus/src/manager.rs index c7ae2e662c..758fbae130 100644 --- a/crates/llm-cactus/src/manager.rs +++ b/crates/llm-cactus/src/manager.rs @@ -1,321 +1,7 @@ -use std::{ - collections::HashMap, - path::{Path, PathBuf}, - sync::Arc, - time::Duration, +pub use hypr_model_manager::{ + ModelLoader, ModelManager as GenericModelManager, + ModelManagerBuilder as GenericModelManagerBuilder, }; -use tokio::sync::{Mutex, RwLock, watch}; - -const DEFAULT_INACTIVITY_TIMEOUT: Duration = Duration::from_secs(150); -const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(3); - -pub trait ModelLoader: Send + Sync + 'static { - fn load(path: &Path) -> Result - where - Self: Sized; -} - -impl ModelLoader for hypr_cactus::Model { - fn load(path: &Path) -> Result { - Ok(hypr_cactus::Model::new(path)?) - } -} - -struct ActiveModel { - name: String, - model: Arc, -} - -struct DropGuard { - shutdown_tx: watch::Sender<()>, -} - -impl Drop for DropGuard { - fn drop(&mut self) { - let _ = self.shutdown_tx.send(()); - } -} - -pub struct ModelManager { - registry: Arc>>, - default_model: Arc>>, - active: Arc>>>, - last_activity: Arc>>, - inactivity_timeout: Duration, - _drop_guard: Arc, -} - -impl Clone for ModelManager { - fn clone(&self) -> Self { - Self { - registry: Arc::clone(&self.registry), - default_model: Arc::clone(&self.default_model), - active: Arc::clone(&self.active), - last_activity: Arc::clone(&self.last_activity), - inactivity_timeout: self.inactivity_timeout, - _drop_guard: Arc::clone(&self._drop_guard), - } - } -} - -impl ModelManager { - pub fn builder() -> ModelManagerBuilder { - ModelManagerBuilder::default() - } - - pub async fn register(&self, name: impl Into, path: impl Into) { - let mut reg = self.registry.write().await; - reg.insert(name.into(), path.into()); - } - - pub async fn unregister(&self, name: &str) { - let mut reg = self.registry.write().await; - reg.remove(name); - - let mut active = self.active.lock().await; - if active.as_ref().is_some_and(|a| a.name == name) { - *active = None; - } - } - - pub async fn set_default(&self, name: impl Into) { - let mut default = self.default_model.write().await; - *default = Some(name.into()); - } - - pub async fn get(&self, name: Option<&str>) -> Result, crate::Error> { - let resolved = match name { - Some(n) => n.to_string(), - None => { - let default = self.default_model.read().await; - default.clone().ok_or(crate::Error::NoDefaultModel)? - } - }; - - let path = { - let reg = self.registry.read().await; - reg.get(&resolved) - .cloned() - .ok_or_else(|| crate::Error::ModelNotRegistered(resolved.clone()))? - }; - - if !path.exists() { - return Err(crate::Error::ModelFileNotFound(path.display().to_string())); - } - - self.update_activity().await; - - let mut active = self.active.lock().await; - - if let Some(ref a) = *active { - if a.name == resolved { - return Ok(Arc::clone(&a.model)); - } - } - - *active = None; - - let model = tokio::task::spawn_blocking(move || M::load(&path)) - .await - .map_err(|_| crate::Error::WorkerPanicked)??; - - let model = Arc::new(model); - *active = Some(ActiveModel { - name: resolved, - model: Arc::clone(&model), - }); - - Ok(model) - } - - async fn update_activity(&self) { - *self.last_activity.lock().await = Some(tokio::time::Instant::now()); - } - - fn spawn_monitor(&self, check_interval: Duration, mut shutdown_rx: watch::Receiver<()>) { - let active = Arc::clone(&self.active); - let last_activity = Arc::clone(&self.last_activity); - let inactivity_timeout = self.inactivity_timeout; - - tokio::spawn(async move { - let mut interval = tokio::time::interval(check_interval); - interval.tick().await; - - loop { - tokio::select! { - _ = shutdown_rx.changed() => break, - _ = interval.tick() => { - let last = last_activity.lock().await; - if let Some(t) = *last { - if t.elapsed() > inactivity_timeout { - *active.lock().await = None; - } - } - } - } - } - }); - } -} - -pub struct ModelManagerBuilder { - models: HashMap, - default_model: Option, - inactivity_timeout: Option, - check_interval: Option, - _phantom: std::marker::PhantomData, -} - -impl Default for ModelManagerBuilder { - fn default() -> Self { - Self { - models: HashMap::new(), - default_model: None, - inactivity_timeout: None, - check_interval: None, - _phantom: std::marker::PhantomData, - } - } -} - -impl ModelManagerBuilder { - pub fn register(mut self, name: impl Into, path: impl Into) -> Self { - self.models.insert(name.into(), path.into()); - self - } - - pub fn default_model(mut self, name: impl Into) -> Self { - self.default_model = Some(name.into()); - self - } - - pub fn inactivity_timeout(mut self, timeout: Duration) -> Self { - self.inactivity_timeout = Some(timeout); - self - } - - pub fn check_interval(mut self, interval: Duration) -> Self { - self.check_interval = Some(interval); - self - } - - pub fn build(self) -> ModelManager { - let (shutdown_tx, shutdown_rx) = watch::channel(()); - let inactivity_timeout = self - .inactivity_timeout - .unwrap_or(DEFAULT_INACTIVITY_TIMEOUT); - let check_interval = self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL); - - let manager = ModelManager { - registry: Arc::new(RwLock::new(self.models)), - default_model: Arc::new(RwLock::new(self.default_model)), - active: Arc::new(Mutex::new(None)), - last_activity: Arc::new(Mutex::new(None)), - inactivity_timeout, - _drop_guard: Arc::new(DropGuard { shutdown_tx }), - }; - - manager.spawn_monitor(check_interval, shutdown_rx); - manager - } -} - -#[cfg(test)] -mod tests { - use super::*; - - struct MockModel; - - impl ModelLoader for MockModel { - fn load(_path: &Path) -> Result { - Ok(MockModel) - } - } - - fn temp_model_path() -> PathBuf { - let dir = std::env::temp_dir().join("llm-cactus-tests"); - std::fs::create_dir_all(&dir).unwrap(); - let path = dir.join(format!("{}.bin", uuid::Uuid::new_v4())); - std::fs::write(&path, b"").unwrap(); - path - } - - fn build_manager( - timeout: Duration, - check_interval: Duration, - models: &[(&str, PathBuf)], - ) -> ModelManager { - let mut builder = ModelManager::::builder() - .inactivity_timeout(timeout) - .check_interval(check_interval); - for (name, path) in models { - builder = builder.register(*name, path.clone()); - } - builder.build() - } - - #[tokio::test(start_paused = true)] - async fn idle_model_gets_evicted() { - let path = temp_model_path(); - let mgr = build_manager( - Duration::from_millis(100), - Duration::from_millis(10), - &[("a", path)], - ); - - let m1 = mgr.get(Some("a")).await.unwrap(); - let m2 = mgr.get(Some("a")).await.unwrap(); - assert!(Arc::ptr_eq(&m1, &m2)); - - tokio::time::advance(Duration::from_millis(120)).await; - tokio::task::yield_now().await; - - let m3 = mgr.get(Some("a")).await.unwrap(); - assert!(!Arc::ptr_eq(&m1, &m3)); - } - - #[tokio::test(start_paused = true)] - async fn activity_prevents_eviction() { - let path = temp_model_path(); - let mgr = build_manager( - Duration::from_millis(100), - Duration::from_millis(10), - &[("a", path)], - ); - - let m1 = mgr.get(Some("a")).await.unwrap(); - - for _ in 0..5 { - tokio::time::advance(Duration::from_millis(50)).await; - tokio::task::yield_now().await; - - let m = mgr.get(Some("a")).await.unwrap(); - assert!(Arc::ptr_eq(&m1, &m)); - } - } - - #[tokio::test(start_paused = true)] - async fn access_near_timeout_resets_timer() { - let path = temp_model_path(); - let mgr = build_manager( - Duration::from_millis(100), - Duration::from_millis(10), - &[("a", path)], - ); - - let m1 = mgr.get(Some("a")).await.unwrap(); - - tokio::time::advance(Duration::from_millis(90)).await; - tokio::task::yield_now().await; - - let m2 = mgr.get(Some("a")).await.unwrap(); - assert!(Arc::ptr_eq(&m1, &m2)); - - tokio::time::advance(Duration::from_millis(50)).await; - tokio::task::yield_now().await; - - let m3 = mgr.get(Some("a")).await.unwrap(); - assert!(Arc::ptr_eq(&m1, &m3)); - } -} +pub type ModelManager = GenericModelManager; +pub type ModelManagerBuilder = GenericModelManagerBuilder; diff --git a/crates/model-manager/Cargo.toml b/crates/model-manager/Cargo.toml new file mode 100644 index 0000000000..272f3ede9a --- /dev/null +++ b/crates/model-manager/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "model-manager" +version = "0.1.0" +edition = "2024" + +[dependencies] +thiserror = { workspace = true } +tokio = { workspace = true, features = ["sync", "time", "rt", "macros"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["test-util", "macros", "rt"] } +uuid = { workspace = true, features = ["v4"] } diff --git a/crates/model-manager/src/builder.rs b/crates/model-manager/src/builder.rs new file mode 100644 index 0000000000..2446e68990 --- /dev/null +++ b/crates/model-manager/src/builder.rs @@ -0,0 +1,66 @@ +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; + +use tokio::sync::{RwLock, watch}; + +use crate::loader::ModelLoader; +use crate::manager::{DEFAULT_CHECK_INTERVAL, DEFAULT_INACTIVITY_TIMEOUT, ModelManager}; + +pub struct ModelManagerBuilder { + models: HashMap, + default_model: Option, + inactivity_timeout: Option, + check_interval: Option, + _phantom: std::marker::PhantomData, +} + +impl Default for ModelManagerBuilder { + fn default() -> Self { + Self { + models: HashMap::new(), + default_model: None, + inactivity_timeout: None, + check_interval: None, + _phantom: std::marker::PhantomData, + } + } +} + +impl ModelManagerBuilder { + pub fn register(mut self, name: impl Into, path: impl Into) -> Self { + self.models.insert(name.into(), path.into()); + self + } + + pub fn default_model(mut self, name: impl Into) -> Self { + self.default_model = Some(name.into()); + self + } + + pub fn inactivity_timeout(mut self, timeout: Duration) -> Self { + self.inactivity_timeout = Some(timeout); + self + } + + pub fn check_interval(mut self, interval: Duration) -> Self { + self.check_interval = Some(interval); + self + } + + pub fn build(self) -> ModelManager { + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let inactivity_timeout = self + .inactivity_timeout + .unwrap_or(DEFAULT_INACTIVITY_TIMEOUT); + let check_interval = self.check_interval.unwrap_or(DEFAULT_CHECK_INTERVAL); + + let manager = ModelManager::new( + Arc::new(RwLock::new(self.models)), + Arc::new(RwLock::new(self.default_model)), + inactivity_timeout, + shutdown_tx, + ); + + manager.spawn_monitor(check_interval, shutdown_rx); + manager + } +} diff --git a/crates/model-manager/src/lib.rs b/crates/model-manager/src/lib.rs new file mode 100644 index 0000000000..93b75eee15 --- /dev/null +++ b/crates/model-manager/src/lib.rs @@ -0,0 +1,10 @@ +mod builder; +mod loader; +mod manager; + +#[cfg(test)] +mod tests; + +pub use builder::ModelManagerBuilder; +pub use loader::{Error, ModelLoader, ModelStatus, TryGetResult}; +pub use manager::ModelManager; diff --git a/crates/model-manager/src/loader.rs b/crates/model-manager/src/loader.rs new file mode 100644 index 0000000000..6989456fd1 --- /dev/null +++ b/crates/model-manager/src/loader.rs @@ -0,0 +1,38 @@ +use std::{fmt, path::Path, sync::Arc}; + +pub trait ModelLoader: Send + Sync + 'static { + type Error: fmt::Display + fmt::Debug + Send + 'static; + + fn load(path: &Path) -> Result + where + Self: Sized; +} + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("model not registered: {0}")] + ModelNotRegistered(String), + #[error("model file not found: {0}")] + ModelFileNotFound(String), + #[error("no default model configured")] + NoDefaultModel, + #[error("worker task panicked")] + WorkerPanicked, + #[error(transparent)] + Load(E), +} + +pub enum TryGetResult { + Ready(Arc), + Loading, + NotRegistered, + Failed(String), +} + +pub enum ModelStatus { + Ready(Arc), + Loading, + Idle, + NotRegistered, + Failed(String), +} diff --git a/crates/model-manager/src/manager.rs b/crates/model-manager/src/manager.rs new file mode 100644 index 0000000000..3e1d90cba5 --- /dev/null +++ b/crates/model-manager/src/manager.rs @@ -0,0 +1,343 @@ +use std::{collections::HashMap, path::PathBuf, sync::Arc, time::Duration}; + +use tokio::sync::{Mutex, RwLock, watch}; + +use crate::loader::{Error, ModelLoader, ModelStatus, TryGetResult}; + +pub(crate) const DEFAULT_INACTIVITY_TIMEOUT: Duration = Duration::from_secs(150); +pub(crate) const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(3); + +pub(crate) struct ActiveModel { + pub(crate) name: String, + pub(crate) model: Arc, +} + +pub(crate) enum LoadState { + Loading(tokio::task::JoinHandle>), + Failed(String), +} + +struct DropGuard { + shutdown_tx: watch::Sender<()>, +} + +impl Drop for DropGuard { + fn drop(&mut self) { + let _ = self.shutdown_tx.send(()); + } +} + +pub struct ModelManager { + pub(crate) registry: Arc>>, + pub(crate) default_model: Arc>>, + pub(crate) active: Arc>>>, + pub(crate) loading: Arc>>>, + pub(crate) last_activity: Arc>>, + pub(crate) inactivity_timeout: Duration, + _drop_guard: Arc, +} + +impl Clone for ModelManager { + fn clone(&self) -> Self { + Self { + registry: Arc::clone(&self.registry), + default_model: Arc::clone(&self.default_model), + active: Arc::clone(&self.active), + loading: Arc::clone(&self.loading), + last_activity: Arc::clone(&self.last_activity), + inactivity_timeout: self.inactivity_timeout, + _drop_guard: Arc::clone(&self._drop_guard), + } + } +} + +impl ModelManager { + pub fn builder() -> crate::ModelManagerBuilder { + crate::ModelManagerBuilder::default() + } + + pub async fn get_path(&self, name: &str) -> Option { + self.registry.read().await.get(name).cloned() + } + + pub async fn get_default_path(&self) -> Option { + let default = self.default_model.read().await; + let name = default.as_deref()?; + self.get_path(name).await + } + + pub async fn register(&self, name: impl Into, path: impl Into) { + let mut reg = self.registry.write().await; + reg.insert(name.into(), path.into()); + } + + pub async fn unregister(&self, name: &str) { + let mut reg = self.registry.write().await; + reg.remove(name); + + let mut active = self.active.lock().await; + if active.as_ref().is_some_and(|a| a.name == name) { + *active = None; + } + } + + pub async fn set_default(&self, name: impl Into) { + let mut default = self.default_model.write().await; + *default = Some(name.into()); + } + + pub async fn get(&self, name: Option<&str>) -> Result, Error> { + let resolved = match name { + Some(n) => n.to_string(), + None => { + let default = self.default_model.read().await; + default.clone().ok_or(Error::NoDefaultModel)? + } + }; + + let path = { + let reg = self.registry.read().await; + reg.get(&resolved) + .cloned() + .ok_or_else(|| Error::ModelNotRegistered(resolved.clone()))? + }; + + if !path.exists() { + return Err(Error::ModelFileNotFound(path.display().to_string())); + } + + self.update_activity().await; + + let mut active = self.active.lock().await; + + if let Some(ref a) = *active { + if a.name == resolved { + return Ok(Arc::clone(&a.model)); + } + } + + *active = None; + + let model = tokio::task::spawn_blocking(move || M::load(&path)) + .await + .map_err(|_| Error::WorkerPanicked)? + .map_err(Error::Load)?; + + let model = Arc::new(model); + *active = Some(ActiveModel { + name: resolved, + model: Arc::clone(&model), + }); + + Ok(model) + } + + /// Non-blocking model access. Returns the cached model if available, + /// kicks off a background load if not, and returns `Loading` while + /// the model is being built (e.g., CoreML compilation). + pub async fn try_get(&self, name: Option<&str>) -> TryGetResult { + let resolved: String = match name { + Some(n) => n.to_string(), + None => { + let default = self.default_model.read().await; + match default.clone() { + Some(n) => n, + None => return TryGetResult::NotRegistered, + } + } + }; + + let path: PathBuf = { + let reg = self.registry.read().await; + match reg.get(&resolved).cloned() { + Some(p) => p, + None => return TryGetResult::NotRegistered, + } + }; + + // Fast path: model already loaded + { + let active = self.active.lock().await; + if let Some(ref a) = *active { + if a.name == resolved { + self.update_activity().await; + return TryGetResult::Ready(Arc::clone(&a.model)); + } + } + } + + // Check/start background load + let mut loading_guard = self.loading.lock().await; + + // Check if a previous load completed or failed + if let Some(load_state) = loading_guard.take() { + match load_state { + LoadState::Loading(handle) if handle.is_finished() => match handle.await { + Ok(Ok(model)) => { + let model = Arc::new(model); + let mut active = self.active.lock().await; + *active = Some(ActiveModel { + name: resolved, + model: Arc::clone(&model), + }); + self.update_activity().await; + return TryGetResult::Ready(model); + } + Ok(Err(e)) => { + let msg = format!("{e}"); + *loading_guard = Some(LoadState::Failed(msg.clone())); + return TryGetResult::Failed(msg); + } + Err(_) => { + let msg = "worker task panicked".to_string(); + *loading_guard = Some(LoadState::Failed(msg.clone())); + return TryGetResult::Failed(msg); + } + }, + LoadState::Loading(handle) => { + *loading_guard = Some(LoadState::Loading(handle)); + return TryGetResult::Loading; + } + LoadState::Failed(msg) => { + *loading_guard = Some(LoadState::Failed(msg.clone())); + return TryGetResult::Failed(msg); + } + } + } + + // No active load — start one + if !path.exists() { + return TryGetResult::NotRegistered; + } + + self.update_activity().await; + let handle = tokio::task::spawn_blocking(move || M::load(&path)); + *loading_guard = Some(LoadState::Loading(handle)); + TryGetResult::Loading + } + + /// Passive model inspection. Unlike `try_get`, this does not start a load + /// and does not refresh activity for already-ready models. + pub async fn status(&self, name: Option<&str>) -> ModelStatus { + let resolved: String = match name { + Some(n) => n.to_string(), + None => { + let default = self.default_model.read().await; + match default.clone() { + Some(n) => n, + None => return ModelStatus::NotRegistered, + } + } + }; + + let path: PathBuf = { + let reg = self.registry.read().await; + match reg.get(&resolved).cloned() { + Some(p) => p, + None => return ModelStatus::NotRegistered, + } + }; + + { + let active = self.active.lock().await; + if let Some(ref a) = *active { + if a.name == resolved { + return ModelStatus::Ready(Arc::clone(&a.model)); + } + } + } + + let mut loading_guard = self.loading.lock().await; + if let Some(load_state) = loading_guard.take() { + match load_state { + LoadState::Loading(handle) if handle.is_finished() => match handle.await { + Ok(Ok(model)) => { + let model = Arc::new(model); + let mut active = self.active.lock().await; + *active = Some(ActiveModel { + name: resolved, + model: Arc::clone(&model), + }); + return ModelStatus::Ready(model); + } + Ok(Err(e)) => { + let msg = format!("{e}"); + *loading_guard = Some(LoadState::Failed(msg.clone())); + return ModelStatus::Failed(msg); + } + Err(_) => { + let msg = "worker task panicked".to_string(); + *loading_guard = Some(LoadState::Failed(msg.clone())); + return ModelStatus::Failed(msg); + } + }, + LoadState::Loading(handle) => { + *loading_guard = Some(LoadState::Loading(handle)); + return ModelStatus::Loading; + } + LoadState::Failed(msg) => { + *loading_guard = Some(LoadState::Failed(msg.clone())); + return ModelStatus::Failed(msg); + } + } + } + + if !path.exists() { + return ModelStatus::NotRegistered; + } + + ModelStatus::Idle + } + + async fn update_activity(&self) { + *self.last_activity.lock().await = Some(tokio::time::Instant::now()); + } + + pub(crate) fn spawn_monitor( + &self, + check_interval: Duration, + mut shutdown_rx: watch::Receiver<()>, + ) { + let active = Arc::clone(&self.active); + let loading = Arc::clone(&self.loading); + let last_activity = Arc::clone(&self.last_activity); + let inactivity_timeout = self.inactivity_timeout; + + tokio::spawn(async move { + let mut interval = tokio::time::interval(check_interval); + interval.tick().await; + + loop { + tokio::select! { + _ = shutdown_rx.changed() => break, + _ = interval.tick() => { + let last = last_activity.lock().await; + if let Some(t) = *last { + if t.elapsed() > inactivity_timeout { + *active.lock().await = None; + *loading.lock().await = None; + } + } + } + } + } + }); + } + + pub(crate) fn new( + registry: Arc>>, + default_model: Arc>>, + inactivity_timeout: Duration, + shutdown_tx: watch::Sender<()>, + ) -> Self { + Self { + registry, + default_model, + active: Arc::new(Mutex::new(None)), + loading: Arc::new(Mutex::new(None)), + last_activity: Arc::new(Mutex::new(None)), + inactivity_timeout, + _drop_guard: Arc::new(DropGuard { shutdown_tx }), + } + } +} diff --git a/crates/model-manager/src/tests.rs b/crates/model-manager/src/tests.rs new file mode 100644 index 0000000000..b03b0fabcd --- /dev/null +++ b/crates/model-manager/src/tests.rs @@ -0,0 +1,224 @@ +use std::{path::Path, path::PathBuf, sync::Arc, time::Duration}; + +use crate::{ModelLoader, ModelManager, ModelStatus, TryGetResult}; + +struct MockModel; + +impl ModelLoader for MockModel { + type Error = String; + + fn load(_path: &Path) -> Result { + Ok(MockModel) + } +} + +fn temp_model_path() -> PathBuf { + let dir = std::env::temp_dir().join("cactus-model-manager-tests"); + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join(format!("{}.bin", uuid::Uuid::new_v4())); + std::fs::write(&path, b"").unwrap(); + path +} + +fn build_manager( + timeout: Duration, + check_interval: Duration, + models: &[(&str, PathBuf)], +) -> ModelManager { + let mut builder = ModelManager::::builder() + .inactivity_timeout(timeout) + .check_interval(check_interval); + for (name, path) in models { + builder = builder.register(*name, path.clone()); + } + builder.build() +} + +async fn wait_for_try_get_ready(mgr: &ModelManager) { + for _ in 0..50 { + if matches!(mgr.try_get(None).await, TryGetResult::Ready(_)) { + return; + } + tokio::task::yield_now().await; + } + + panic!("model did not become ready via try_get"); +} + +async fn wait_for_status_ready(mgr: &ModelManager) { + for _ in 0..50 { + if matches!(mgr.status(None).await, ModelStatus::Ready(_)) { + return; + } + tokio::task::yield_now().await; + } + + panic!("model did not become ready via status"); +} + +#[tokio::test(start_paused = true)] +async fn idle_model_gets_evicted() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + + let m1 = mgr.get(Some("a")).await.unwrap(); + let m2 = mgr.get(Some("a")).await.unwrap(); + assert!(Arc::ptr_eq(&m1, &m2)); + + tokio::time::advance(Duration::from_millis(120)).await; + tokio::task::yield_now().await; + + let m3 = mgr.get(Some("a")).await.unwrap(); + assert!(!Arc::ptr_eq(&m1, &m3)); +} + +#[tokio::test(start_paused = true)] +async fn activity_prevents_eviction() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + + let m1 = mgr.get(Some("a")).await.unwrap(); + + for _ in 0..5 { + tokio::time::advance(Duration::from_millis(50)).await; + tokio::task::yield_now().await; + + let m = mgr.get(Some("a")).await.unwrap(); + assert!(Arc::ptr_eq(&m1, &m)); + } +} + +#[tokio::test(start_paused = true)] +async fn access_near_timeout_resets_timer() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + + let m1 = mgr.get(Some("a")).await.unwrap(); + + tokio::time::advance(Duration::from_millis(90)).await; + tokio::task::yield_now().await; + + let m2 = mgr.get(Some("a")).await.unwrap(); + assert!(Arc::ptr_eq(&m1, &m2)); + + tokio::time::advance(Duration::from_millis(50)).await; + tokio::task::yield_now().await; + + let m3 = mgr.get(Some("a")).await.unwrap(); + assert!(Arc::ptr_eq(&m1, &m3)); +} + +#[tokio::test(start_paused = true)] +async fn try_get_returns_loading_then_ready() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + mgr.set_default("a").await; + + // First try_get should start loading + let result = mgr.try_get(None).await; + assert!(matches!(result, TryGetResult::Loading)); + + wait_for_try_get_ready(&mgr).await; +} + +#[tokio::test(start_paused = true)] +async fn try_get_not_registered() { + let mgr = build_manager(Duration::from_millis(100), Duration::from_millis(10), &[]); + + let result = mgr.try_get(Some("nonexistent")).await; + assert!(matches!(result, TryGetResult::NotRegistered)); +} + +#[tokio::test(start_paused = true)] +async fn try_get_after_eviction_reloads() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + mgr.set_default("a").await; + + // Load via blocking get + let _m1 = mgr.get(Some("a")).await.unwrap(); + + // Evict + tokio::time::advance(Duration::from_millis(120)).await; + tokio::task::yield_now().await; + + // try_get should trigger reload + let result = mgr.try_get(None).await; + assert!(matches!(result, TryGetResult::Loading)); +} + +#[tokio::test(start_paused = true)] +async fn status_is_idle_before_first_request() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + mgr.set_default("a").await; + + let result = mgr.status(None).await; + assert!(matches!(result, ModelStatus::Idle)); +} + +#[tokio::test(start_paused = true)] +async fn status_reports_loading_without_starting_load() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + mgr.set_default("a").await; + + assert!(matches!(mgr.status(None).await, ModelStatus::Idle)); + assert!(matches!(mgr.status(None).await, ModelStatus::Idle)); + + let result = mgr.try_get(None).await; + assert!(matches!(result, TryGetResult::Loading)); + assert!(matches!(mgr.status(None).await, ModelStatus::Loading)); +} + +#[tokio::test(start_paused = true)] +async fn status_promotes_completed_load_without_refreshing_activity() { + let path = temp_model_path(); + let mgr = build_manager( + Duration::from_millis(100), + Duration::from_millis(10), + &[("a", path)], + ); + mgr.set_default("a").await; + + let result = mgr.try_get(None).await; + assert!(matches!(result, TryGetResult::Loading)); + + wait_for_status_ready(&mgr).await; + + tokio::time::advance(Duration::from_millis(90)).await; + tokio::task::yield_now().await; + assert!(matches!(mgr.status(None).await, ModelStatus::Ready(_))); + + tokio::time::advance(Duration::from_millis(20)).await; + tokio::task::yield_now().await; + assert!(matches!(mgr.status(None).await, ModelStatus::Idle)); +} diff --git a/crates/owhisper-client/Cargo.toml b/crates/owhisper-client/Cargo.toml index 7da861032b..45b5ee243c 100644 --- a/crates/owhisper-client/Cargo.toml +++ b/crates/owhisper-client/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [features] default = [] -argmax = ["hypr-audio-utils"] +local = ["hypr-audio-utils"] [dependencies] hypr-am = { workspace = true } @@ -25,6 +25,7 @@ reqwest-tracing = { workspace = true } tokio = { workspace = true, features = ["fs"] } tokio-stream = { workspace = true } +backon = { workspace = true } base64 = { workspace = true } bytes = { workspace = true } serde = { workspace = true } diff --git a/crates/owhisper-client/src/adapter/argmax/mod.rs b/crates/owhisper-client/src/adapter/argmax/mod.rs index 422a30a74b..7a0204e7ed 100644 --- a/crates/owhisper-client/src/adapter/argmax/mod.rs +++ b/crates/owhisper-client/src/adapter/argmax/mod.rs @@ -1,10 +1,10 @@ -#[cfg(feature = "argmax")] +#[cfg(feature = "local")] mod batch; pub(crate) mod keywords; pub(crate) mod language; mod live; -#[cfg(feature = "argmax")] +#[cfg(feature = "local")] pub use batch::StreamingBatchConfig; pub use language::PARAKEET_V3_LANGS; diff --git a/crates/owhisper-client/src/adapter/cactus/batch.rs b/crates/owhisper-client/src/adapter/cactus/batch.rs index 02813b67df..561e8c191c 100644 --- a/crates/owhisper-client/src/adapter/cactus/batch.rs +++ b/crates/owhisper-client/src/adapter/cactus/batch.rs @@ -30,26 +30,10 @@ impl CactusAdapter { .map_err(|e| Error::AudioProcessing(format!("task panicked: {:?}", e)))??; let url = build_cactus_batch_url(api_base, params); - let client = reqwest::Client::new(); - let response = client - .post(url) - .header("Content-Type", &content_type) - .header("Accept", "text/event-stream") - .body(audio_data) - .send() - .await?; - - let status = response.status(); - if !status.is_success() { - let body = response.text().await.unwrap_or_default(); - tracing::error!( - http.response.status_code = status.as_u16(), - hyprnote.http.response.body = %body, - "unexpected_response_status" - ); - return Err(Error::UnexpectedStatus { status, body }); - } + + let response = + super::retry::post_with_retry(&client, url, &content_type, audio_data).await?; let byte_stream = response.bytes_stream(); diff --git a/crates/owhisper-client/src/adapter/cactus/mod.rs b/crates/owhisper-client/src/adapter/cactus/mod.rs index b8a1bd9fd4..89080dcd7e 100644 --- a/crates/owhisper-client/src/adapter/cactus/mod.rs +++ b/crates/owhisper-client/src/adapter/cactus/mod.rs @@ -1,6 +1,8 @@ -#[cfg(feature = "argmax")] +#[cfg(feature = "local")] mod batch; mod live; +#[cfg(feature = "local")] +mod retry; #[derive(Clone, Default)] pub struct CactusAdapter; diff --git a/crates/owhisper-client/src/adapter/cactus/retry.rs b/crates/owhisper-client/src/adapter/cactus/retry.rs new file mode 100644 index 0000000000..2c2ec324fd --- /dev/null +++ b/crates/owhisper-client/src/adapter/cactus/retry.rs @@ -0,0 +1,89 @@ +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Duration; + +use backon::{ConstantBuilder, Retryable}; +use reqwest::StatusCode; + +use crate::error::Error; + +const MAX_RETRIES: usize = 14; +const DEFAULT_RETRY_DELAY: Duration = Duration::from_secs(5); +const NO_RETRY_AFTER: u64 = u64::MAX; + +pub(super) async fn post_with_retry( + client: &reqwest::Client, + url: url::Url, + content_type: &str, + audio_data: Vec, +) -> Result { + let retry_after_secs = AtomicU64::new(NO_RETRY_AFTER); + + let result = (|| { + let url = url.clone(); + let audio_data = audio_data.clone(); + async { + let resp = client + .post(url) + .header("Content-Type", content_type) + .header("Accept", "text/event-stream") + .body(audio_data) + .send() + .await?; + + if resp.status() == StatusCode::SERVICE_UNAVAILABLE { + let secs = resp + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()) + .unwrap_or(NO_RETRY_AFTER); + retry_after_secs.store(secs, Ordering::SeqCst); + + let body = resp.text().await.unwrap_or_default(); + return Err(Error::UnexpectedStatus { + status: StatusCode::SERVICE_UNAVAILABLE, + body, + }); + } + + Ok(resp) + } + }) + .retry( + ConstantBuilder::default() + .with_delay(DEFAULT_RETRY_DELAY) + .with_max_times(MAX_RETRIES), + ) + .when(|e| { + matches!( + e, + Error::UnexpectedStatus { status, .. } if *status == StatusCode::SERVICE_UNAVAILABLE + ) + }) + .adjust(|_err, dur| { + let secs = retry_after_secs.swap(NO_RETRY_AFTER, Ordering::SeqCst); + if secs == NO_RETRY_AFTER { + dur + } else { + Some(Duration::from_secs(secs)) + } + }) + .notify(|_err, dur| { + tracing::info!(retry_after_secs = dur.as_secs(), "model_loading_retry"); + }) + .await; + + let response = result?; + let status = response.status(); + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + tracing::error!( + http.response.status_code = status.as_u16(), + hyprnote.http.response.body = %body, + "unexpected_response_status" + ); + return Err(Error::UnexpectedStatus { status, body }); + } + + Ok(response) +} diff --git a/crates/owhisper-client/src/lib.rs b/crates/owhisper-client/src/lib.rs index 06ff619213..81e29836de 100644 --- a/crates/owhisper-client/src/lib.rs +++ b/crates/owhisper-client/src/lib.rs @@ -15,7 +15,7 @@ pub use providers::{Auth, Provider, is_meta_model}; use std::marker::PhantomData; -#[cfg(feature = "argmax")] +#[cfg(feature = "local")] pub use adapter::StreamingBatchConfig; pub use adapter::deepgram::DeepgramModel; pub use adapter::{ diff --git a/crates/transcribe-cactus/Cargo.toml b/crates/transcribe-cactus/Cargo.toml index 00ea1d3583..e7469fa7c9 100644 --- a/crates/transcribe-cactus/Cargo.toml +++ b/crates/transcribe-cactus/Cargo.toml @@ -5,8 +5,9 @@ edition = "2024" [dependencies] hypr-audio-utils = { workspace = true } -hypr-cactus = { workspace = true } +hypr-cactus = { workspace = true, features = ["model-manager"] } hypr-language = { workspace = true } +hypr-model-manager = { workspace = true } hypr-vad-chunking = { workspace = true } hypr-ws-utils = { workspace = true } owhisper-interface = { workspace = true } diff --git a/crates/transcribe-cactus/src/service/batch/mod.rs b/crates/transcribe-cactus/src/service/batch/mod.rs index 970a484b8c..7bd4097836 100644 --- a/crates/transcribe-cactus/src/service/batch/mod.rs +++ b/crates/transcribe-cactus/src/service/batch/mod.rs @@ -3,7 +3,7 @@ mod response; mod transcribe; use std::convert::Infallible; -use std::path::Path; +use std::sync::Arc; use axum::{ Json, @@ -24,15 +24,14 @@ pub async fn handle_batch( body: Bytes, content_type: &str, params: &ListenParams, - model_path: &Path, + model: Arc, ) -> Response { - let model_path = model_path.to_path_buf(); let content_type = content_type.to_string(); let params = params.clone(); let result = tokio::task::spawn_blocking(move || { std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - transcribe_batch(&body, &content_type, ¶ms, &model_path, None) + transcribe_batch(&body, &content_type, ¶ms, &model, None) })) }) .await; @@ -68,9 +67,8 @@ pub async fn handle_batch_sse( body: Bytes, content_type: &str, params: &ListenParams, - model_path: &Path, + model: Arc, ) -> Response { - let model_path = model_path.to_path_buf(); let content_type = content_type.to_string(); let params = params.clone(); @@ -82,7 +80,7 @@ pub async fn handle_batch_sse( &body, &content_type, ¶ms, - &model_path, + &model, Some(event_tx.clone()), ) })) { diff --git a/crates/transcribe-cactus/src/service/batch/transcribe.rs b/crates/transcribe-cactus/src/service/batch/transcribe.rs index d22027f429..c53ce9463c 100644 --- a/crates/transcribe-cactus/src/service/batch/transcribe.rs +++ b/crates/transcribe-cactus/src/service/batch/transcribe.rs @@ -1,5 +1,5 @@ use std::io::Write; -use std::path::Path; +use std::sync::Arc; use owhisper_interface::ListenParams; use owhisper_interface::batch; @@ -13,18 +13,17 @@ use super::response::{build_batch_words, build_segment_stream_response}; use hypr_audio_utils::content_type_to_extension; #[tracing::instrument( - skip(audio_data, event_tx), + skip(audio_data, model, event_tx), fields( hyprnote.audio.size_bytes = audio_data.len(), hyprnote.file.mime_type = content_type, - hyprnote.model.path = %model_path.display() ) )] pub(super) fn transcribe_batch( audio_data: &[u8], content_type: &str, params: &ListenParams, - model_path: &Path, + model: &Arc, event_tx: Option>, ) -> Result { let extension = content_type_to_extension(content_type); @@ -45,17 +44,9 @@ pub(super) fn transcribe_batch( .map(|samples| channel_duration_sec(samples)) .fold(0.0_f64, f64::max); - let model = match crate::service::build_model(model_path) { - Ok(m) => m, - Err(e) => { - tracing::error!(error = %e, "failed_to_load_model"); - return Err(e.into()); - } - }; - let options = crate::service::build_transcribe_options(params, None); - let metadata = crate::service::build_metadata(model_path); + let metadata = owhisper_interface::stream::Metadata::default(); let channel_durations = channel_samples .iter() .map(|samples| channel_duration_sec(samples)) @@ -378,8 +369,6 @@ impl ProgressTracker { #[cfg(test)] mod tests { - use std::path::Path; - use hypr_language::ISO639; use owhisper_interface::ListenParams; @@ -490,13 +479,18 @@ mod tests { .to_string_lossy() .into_owned() }); - let model_path = Path::new(&model_path_str); + let model_path = std::path::Path::new(&model_path_str); assert!( model_path.exists(), "model path does not exist: {}", model_path.display() ); + let model = Arc::new( + crate::service::build_model(model_path) + .unwrap_or_else(|e| panic!("failed to build model: {e}")), + ); + let wav_bytes = std::fs::read(hypr_data::english_1::AUDIO_PATH) .unwrap_or_else(|e| panic!("failed to read fixture wav: {e}")); @@ -505,7 +499,7 @@ mod tests { ..Default::default() }; - let response = transcribe_batch(&wav_bytes, "audio/wav", ¶ms, model_path, None) + let response = transcribe_batch(&wav_bytes, "audio/wav", ¶ms, &model, None) .unwrap_or_else(|e| panic!("real-model batch transcription failed: {e}")); let Some(channel) = response.results.channels.first() else { diff --git a/crates/transcribe-cactus/src/service/mod.rs b/crates/transcribe-cactus/src/service/mod.rs index 5aeb2650ee..3ae1a19bc2 100644 --- a/crates/transcribe-cactus/src/service/mod.rs +++ b/crates/transcribe-cactus/src/service/mod.rs @@ -13,9 +13,9 @@ pub(crate) struct Segment<'a> { pub confidence: f64, } -pub(crate) fn build_metadata(model_path: &Path) -> Metadata { +pub(crate) fn build_metadata(model_path: Option<&Path>) -> Metadata { let model_name = model_path - .file_stem() + .and_then(|p| p.file_stem()) .and_then(|s| s.to_str()) .unwrap_or("cactus") .to_string(); @@ -31,6 +31,7 @@ pub(crate) fn build_metadata(model_path: &Path) -> Metadata { } } +#[allow(dead_code)] pub(crate) fn build_model(model_path: &Path) -> Result { static LOG_INIT: std::sync::Once = std::sync::Once::new(); LOG_INIT.call_once(hypr_cactus::log::init); diff --git a/crates/transcribe-cactus/src/service/streaming/service.rs b/crates/transcribe-cactus/src/service/streaming/service.rs index 879315197a..a72ccae6a6 100644 --- a/crates/transcribe-cactus/src/service/streaming/service.rs +++ b/crates/transcribe-cactus/src/service/streaming/service.rs @@ -1,6 +1,5 @@ use std::{ future::Future, - path::PathBuf, pin::Pin, task::{Context, Poll}, }; @@ -8,11 +7,12 @@ use std::{ use axum::{ body::Body, extract::{FromRequestParts, ws::WebSocketUpgrade}, - http::{Request, StatusCode}, + http::{Request, StatusCode, header}, response::{IntoResponse, Response}, }; use tower::Service; +use hypr_model_manager::{ModelManager, TryGetResult}; use hypr_ws_utils::ConnectionManager; use owhisper_interface::ListenParams; @@ -20,9 +20,11 @@ use super::super::batch; use super::session; use crate::CactusConfig; +const MODEL_NAME: &str = "default"; + #[derive(Clone)] pub struct TranscribeService { - model_path: PathBuf, + model_manager: ModelManager, cactus_config: CactusConfig, connection_manager: ConnectionManager, } @@ -31,21 +33,33 @@ impl TranscribeService { pub fn builder() -> TranscribeServiceBuilder { TranscribeServiceBuilder::default() } + + pub fn model_manager(&self) -> &ModelManager { + &self.model_manager + } } #[derive(Default)] pub struct TranscribeServiceBuilder { - model_path: Option, + model_manager: Option>, cactus_config: CactusConfig, connection_manager: Option, } impl TranscribeServiceBuilder { - pub fn model_path(mut self, model_path: PathBuf) -> Self { - self.model_path = Some(model_path); + pub fn model_manager(mut self, model_manager: ModelManager) -> Self { + self.model_manager = Some(model_manager); self } + pub fn model_path(self, model_path: std::path::PathBuf) -> Self { + let model_manager = ModelManager::::builder() + .register(MODEL_NAME, &model_path) + .default_model(MODEL_NAME) + .build(); + self.model_manager(model_manager) + } + pub fn cactus_config(mut self, config: CactusConfig) -> Self { self.cactus_config = config; self @@ -53,9 +67,9 @@ impl TranscribeServiceBuilder { pub fn build(self) -> TranscribeService { TranscribeService { - model_path: self - .model_path - .expect("TranscribeServiceBuilder requires model_path"), + model_manager: self + .model_manager + .expect("TranscribeServiceBuilder requires model_manager"), cactus_config: self.cactus_config, connection_manager: self.connection_manager.unwrap_or_default(), } @@ -72,7 +86,7 @@ impl Service> for TranscribeService { } fn call(&mut self, req: Request) -> Self::Future { - let model_path = self.model_path.clone(); + let model_manager = self.model_manager.clone(); let cactus_config = self.cactus_config.clone(); let connection_manager = self.connection_manager.clone(); @@ -92,19 +106,37 @@ impl Service> for TranscribeService { } }; + // Non-blocking model check — return 503 if still loading + let model = match model_manager.try_get(None).await { + TryGetResult::Ready(model) => model, + TryGetResult::Loading => { + return Ok(Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .header(header::RETRY_AFTER, "1") + .body(Body::from("model is loading")) + .unwrap()); + } + TryGetResult::Failed(msg) => { + tracing::error!(error = %msg, "model_load_failed"); + return Ok(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to load model: {msg}"), + ) + .into_response()); + } + TryGetResult::NotRegistered => { + return Ok(( + StatusCode::INTERNAL_SERVER_ERROR, + "model not registered".to_string(), + ) + .into_response()); + } + }; + + let model_path = model_manager.get_default_path().await; + let metadata = crate::service::build_metadata(model_path.as_deref()); + if is_ws { - let model = match crate::service::build_model(&model_path) { - Ok(model) => std::sync::Arc::new(model), - Err(error) => { - tracing::error!(error = %error, "failed_to_load_model"); - return Ok(( - StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to load model: {error}"), - ) - .into_response()); - } - }; - let metadata = crate::service::build_metadata(&model_path); let (mut parts, _body) = req.into_parts(); let ws_upgrade = match WebSocketUpgrade::from_request_parts(&mut parts, &()).await { Ok(ws) => ws, @@ -156,12 +188,9 @@ impl Service> for TranscribeService { } if accept.contains("text/event-stream") { - Ok( - batch::handle_batch_sse(body_bytes, &content_type, ¶ms, &model_path) - .await, - ) + Ok(batch::handle_batch_sse(body_bytes, &content_type, ¶ms, model).await) } else { - Ok(batch::handle_batch(body_bytes, &content_type, ¶ms, &model_path).await) + Ok(batch::handle_batch(body_bytes, &content_type, ¶ms, model).await) } } }) diff --git a/crates/transcript/src/segments/collect.rs b/crates/transcript/src/segments/collect.rs index cd4b40e071..deb482f75c 100644 --- a/crates/transcript/src/segments/collect.rs +++ b/crates/transcript/src/segments/collect.rs @@ -9,7 +9,7 @@ pub(super) fn collect_segments( frames: Vec, options: Option<&SegmentBuilderOptions>, ) -> Vec { - let max_gap_ms = options.and_then(|opts| opts.max_gap_ms).unwrap_or(2000); + let max_gap_ms = options.and_then(|opts| opts.max_gap_ms).unwrap_or(3000); let mut segments: Vec = Vec::new(); let mut last_segment_by_channel: HashMap = HashMap::new(); diff --git a/crates/transcript/src/words/stitch.rs b/crates/transcript/src/words/stitch.rs index 606942624d..64a8e1c1ec 100644 --- a/crates/transcript/src/words/stitch.rs +++ b/crates/transcript/src/words/stitch.rs @@ -30,7 +30,19 @@ pub(crate) fn stitch( const STITCH_MAX_GAP_MS: i64 = 300; fn should_stitch(tail: &RawWord, head: &RawWord) -> bool { - !head.text.starts_with(' ') && (head.start_ms - tail.end_ms) <= STITCH_MAX_GAP_MS + !head.text.starts_with(' ') + && (head.start_ms - tail.end_ms) <= STITCH_MAX_GAP_MS + && !is_sentence_boundary(&tail.text, &head.text) +} + +fn is_sentence_boundary(tail_text: &str, head_text: &str) -> bool { + let ends_with_terminal = tail_text + .trim_end() + .ends_with(|c: char| matches!(c, '.' | '!' | '?')); + let starts_with_upper = head_text + .trim_start() + .starts_with(|c: char| c.is_uppercase()); + ends_with_terminal && starts_with_upper } fn merge_words(mut left: RawWord, right: RawWord) -> RawWord { @@ -83,4 +95,32 @@ mod tests { let head = word(",", 500, 510); assert!(!should_stitch(&tail, &head)); } + + #[test] + fn does_not_stitch_across_sentence_boundary() { + let tail = word(" you.", 0, 100); + let head = word("Ultimately", 100, 200); + assert!(!should_stitch(&tail, &head)); + } + + #[test] + fn does_not_stitch_across_exclamation_boundary() { + let tail = word(" great!", 0, 100); + let head = word("Thank", 100, 200); + assert!(!should_stitch(&tail, &head)); + } + + #[test] + fn still_stitches_abbreviation_continuation() { + let tail = word(" example.", 0, 100); + let head = word("com", 100, 200); + assert!(should_stitch(&tail, &head)); + } + + #[test] + fn still_stitches_decimal_after_period() { + let tail = word(" 3.", 0, 100); + let head = word("14", 100, 200); + assert!(should_stitch(&tail, &head)); + } } diff --git a/crates/ws-client/Cargo.toml b/crates/ws-client/Cargo.toml index e7a24f8fa9..ef62f1ec25 100644 --- a/crates/ws-client/Cargo.toml +++ b/crates/ws-client/Cargo.toml @@ -9,6 +9,7 @@ serde = { workspace = true, features = ["derive"] } thiserror = { workspace = true } async-stream = { workspace = true } +backon = { workspace = true } futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "time", "sync", "macros"] } tokio-tungstenite = { workspace = true, features = ["native-tls"] } diff --git a/crates/ws-client/src/error.rs b/crates/ws-client/src/error.rs index 7d057f5bf0..db8ee8271f 100644 --- a/crates/ws-client/src/error.rs +++ b/crates/ws-client/src/error.rs @@ -14,6 +14,7 @@ pub enum Error { is_auth: bool, status_code: Option, retryable: bool, + retry_after_secs: Option, }, #[error("connect retries exhausted after {attempts} attempts: {last_error}")] ConnectRetriesExhausted { attempts: usize, last_error: String }, @@ -62,10 +63,11 @@ impl std::fmt::Debug for Error { is_auth, status_code, retryable, + retry_after_secs, } => { write!( f, - "ConnectFailed({attempt}/{max_attempts}, auth={is_auth}, status={status_code:?}, retryable={retryable}, {message})" + "ConnectFailed({attempt}/{max_attempts}, auth={is_auth}, status={status_code:?}, retryable={retryable}, retry_after={retry_after_secs:?}, {message})" ) } Error::ConnectRetriesExhausted { @@ -151,6 +153,7 @@ impl Error { is_auth: is_http_auth_error(error), status_code: http_status(error), retryable: is_retryable_handshake_error(error), + retry_after_secs: extract_retry_after(error), } } @@ -208,6 +211,17 @@ fn http_status(error: &tokio_tungstenite::tungstenite::Error) -> Option { None } +fn extract_retry_after(error: &tokio_tungstenite::tungstenite::Error) -> Option { + if let tokio_tungstenite::tungstenite::Error::Http(response) = error { + return response + .headers() + .get("retry-after") + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + } + None +} + fn is_retryable_http_status(status: u16) -> bool { matches!(status, 408 | 425 | 429) || (500..=599).contains(&status) } diff --git a/crates/ws-client/src/retry.rs b/crates/ws-client/src/retry.rs index 3f5c15b0d4..b878acc605 100644 --- a/crates/ws-client/src/retry.rs +++ b/crates/ws-client/src/retry.rs @@ -1,20 +1,24 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use backon::{ConstantBuilder, Retryable}; use tokio_tungstenite::{connect_async, tungstenite::client::IntoClientRequest}; pub type WebSocketRetryCallback = std::sync::Arc; #[derive(Debug, Clone)] pub struct WebSocketConnectPolicy { - pub connect_timeout: std::time::Duration, + pub connect_timeout: Duration, pub max_attempts: usize, - pub retry_delay: std::time::Duration, + pub retry_delay: Duration, } impl Default for WebSocketConnectPolicy { fn default() -> Self { Self { - connect_timeout: std::time::Duration::from_secs(5), + connect_timeout: Duration::from_secs(5), max_attempts: 3, - retry_delay: std::time::Duration::from_millis(750), + retry_delay: Duration::from_millis(750), } } } @@ -35,62 +39,59 @@ pub(crate) async fn connect_with_retry( crate::Error, > { let max_attempts = policy.max_attempts.max(1); - let mut attempts_made = 0usize; - let mut last_error: Option = None; - - for attempt in 1..=max_attempts { - attempts_made = attempt; - match try_connect( - request.clone(), - policy.connect_timeout, - attempt, - max_attempts, - ) - .await - { - Ok(stream) => return Ok(stream), - Err(error) => { - tracing::error!("ws_connect_failed: {:?}", error); - - if !error.is_retryable_connect_error() { - return Err(error); - } - - if attempt >= max_attempts { - last_error = Some(error); - break; - } - - if let Some(callback) = on_retry { - callback(WebSocketRetryEvent { - attempt: attempt + 1, - max_attempts, - error: error.to_string(), - }); - } + let attempt_count = AtomicUsize::new(0); - last_error = Some(error); - tokio::time::sleep(policy.retry_delay).await; + let result = (|| { + let request = request.clone(); + async { + let attempt = attempt_count.fetch_add(1, Ordering::SeqCst) + 1; + try_connect(request, policy.connect_timeout, attempt, max_attempts).await + } + }) + .retry( + ConstantBuilder::default() + .with_delay(policy.retry_delay) + .with_max_times(max_attempts - 1), + ) + .when(|e: &crate::Error| e.is_retryable_connect_error()) + .adjust(|err, dur| match err { + crate::Error::ConnectFailed { + retry_after_secs: Some(secs), + .. + } => Some(Duration::from_secs(*secs)), + _ => dur, + }) + .notify(|err, _dur| { + tracing::error!("ws_connect_failed: {:?}", err); + if let Some(callback) = on_retry { + callback(WebSocketRetryEvent { + attempt: attempt_count.load(Ordering::SeqCst) + 1, + max_attempts, + error: err.to_string(), + }); + } + }) + .await; + + match result { + Ok(stream) => Ok(stream), + Err(error) => { + let attempts = attempt_count.load(Ordering::SeqCst); + if attempts >= max_attempts { + Err(crate::Error::connect_retries_exhausted( + attempts, + error.to_string(), + )) + } else { + Err(error) } } } - - match last_error { - Some(error @ crate::Error::ConnectRetriesExhausted { .. }) => Err(error), - Some(error) => Err(crate::Error::connect_retries_exhausted( - attempts_made, - error.to_string(), - )), - None => Err(crate::Error::connect_retries_exhausted( - attempts_made, - "connect failed", - )), - } } async fn try_connect( req: tokio_tungstenite::tungstenite::ClientRequestBuilder, - timeout: std::time::Duration, + timeout: Duration, attempt: usize, max_attempts: usize, ) -> Result< @@ -101,7 +102,6 @@ async fn try_connect( .into_client_request() .map_err(|error| crate::Error::invalid_request(error.to_string()))?; - // AWS WAF and similar firewalls reject WebSocket upgrades without a User-Agent. if !req.headers().contains_key("user-agent") { req.headers_mut().insert( "user-agent", diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index b79724f20f..3a7bd28c53 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -85,4 +85,6 @@ port-killer = "0.1.0" port_check = "0.3.0" [target.'cfg(target_arch = "aarch64")'.dependencies] +hypr-cactus = { workspace = true, features = ["model-manager"] } +hypr-model-manager = { workspace = true } hypr-transcribe-cactus = { workspace = true } diff --git a/plugins/local-stt/src/server/internal2.rs b/plugins/local-stt/src/server/internal2.rs index f28fa83d03..190a56167b 100644 --- a/plugins/local-stt/src/server/internal2.rs +++ b/plugins/local-stt/src/server/internal2.rs @@ -4,6 +4,7 @@ use std::{ }; use axum::{Router, error_handling::HandleError}; +use hypr_model_manager::{ModelManager, ModelStatus}; use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; use reqwest::StatusCode; use tower_http::cors::{self, CorsLayer}; @@ -26,6 +27,7 @@ pub struct Internal2STTArgs { pub struct Internal2STTState { base_url: String, model: CactusSttModel, + model_manager: ModelManager, shutdown: tokio::sync::watch::Sender<()>, server_task: tokio::task::JoinHandle<()>, } @@ -59,16 +61,16 @@ impl Actor for Internal2STTActor { tracing::info!(model_path = %model_path.display(), "starting internal2 STT server"); - let cactus_service = HandleError::new( - hypr_transcribe_cactus::TranscribeService::builder() - .model_path(model_path) - .cactus_config(cactus_config) - .build(), - move |err: String| async move { - let _ = myself.send_message(Internal2STTMessage::ServerError(err.clone())); - (StatusCode::INTERNAL_SERVER_ERROR, err) - }, - ); + let transcribe_service = hypr_transcribe_cactus::TranscribeService::builder() + .model_path(model_path) + .cactus_config(cactus_config) + .build(); + let model_manager = transcribe_service.model_manager().clone(); + + let cactus_service = HandleError::new(transcribe_service, move |err: String| async move { + let _ = myself.send_message(Internal2STTMessage::ServerError(err.clone())); + (StatusCode::INTERNAL_SERVER_ERROR, err) + }); let router = Router::new() .route_service("/v1/listen", cactus_service) @@ -99,6 +101,7 @@ impl Actor for Internal2STTActor { Ok(Internal2STTState { base_url, model: model_type, + model_manager, shutdown: shutdown_tx, server_task, }) @@ -123,9 +126,17 @@ impl Actor for Internal2STTActor { match message { Internal2STTMessage::ServerError(e) => Err(e.into()), Internal2STTMessage::GetHealth(reply_port) => { + let status = match state.model_manager.status(None).await { + ModelStatus::Ready(_) | ModelStatus::Idle => ServerStatus::Ready, + ModelStatus::Loading => ServerStatus::Loading, + ModelStatus::Failed(_) | ModelStatus::NotRegistered => { + ServerStatus::Unreachable + } + }; + let info = ServerInfo { url: Some(state.base_url.clone()), - status: ServerStatus::Ready, + status, model: Some(crate::LocalModel::Cactus(state.model.clone())), }; diff --git a/plugins/permissions/src/ext.rs b/plugins/permissions/src/ext.rs index 072fbb5c69..881050e4d0 100644 --- a/plugins/permissions/src/ext.rs +++ b/plugins/permissions/src/ext.rs @@ -45,8 +45,7 @@ impl<'a, R: tauri::Runtime, M: tauri::Manager> Permissions<'a, R, M> { } fn require_audio(&self) -> Result, crate::Error> { - self.audio_provider() - .ok_or(crate::Error::NoAudioProvider) + self.audio_provider().ok_or(crate::Error::NoAudioProvider) } pub async fn open(&self, permission: Permission) -> Result<(), crate::Error> {