diff --git a/src/acp/pool.rs b/src/acp/pool.rs index e1d27bf..0121e24 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -1,7 +1,8 @@ use crate::acp::connection::AcpConnection; use crate::acp::protocol::ConfigOption; +use crate::adapter::{ChannelRef, ChatAdapter}; use crate::config::AgentConfig; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, bail, Result}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::{Mutex, RwLock}; @@ -16,6 +17,11 @@ struct PoolState { /// Lock-free cancel handles: thread_key → (stdin, session_id). /// Stored separately so cancel can work without locking the connection. cancel_handles: HashMap<String, (Arc<Mutex<ChildStdin>>, String)>, + /// Addressing info for each active thread. Populated alongside `active` + /// and pruned together. Used by `begin_shutdown` so the broker can post + /// a notification to every live session without the adapter layer + /// maintaining a parallel cache. Invariant: `addresses.keys() == active.keys()`. + addresses: HashMap<String, (ChannelRef, Arc<dyn ChatAdapter>)>, /// Suspended sessions: thread_key → ACP sessionId. /// Saved on eviction so sessions can be resumed via `session/load`. suspended: HashMap<String, String>, @@ -28,6 +34,10 @@ pub struct SessionPool { state: RwLock<PoolState>, config: AgentConfig, max_sessions: usize, + /// Flipped by `begin_shutdown` to reject new admissions. Checked inside + /// `get_or_create` under the state write lock so admission and snapshot + /// are atomic. + shutting_down: std::sync::atomic::AtomicBool, } type EvictionCandidate = ( @@ -67,15 +77,53 @@ impl SessionPool { state: RwLock::new(PoolState { active: HashMap::new(), cancel_handles: HashMap::new(), + addresses: HashMap::new(), suspended: HashMap::new(), creating: HashMap::new(), }), config, max_sessions, + shutting_down: std::sync::atomic::AtomicBool::new(false), } } - pub async fn get_or_create(&self, thread_id: &str) -> Result<()> { + /// True once `begin_shutdown` has been called. 
Router uses this to show a + /// shutdown-specific message instead of a generic pool error when + /// `get_or_create` rejects admission. + pub fn is_shutting_down(&self) -> bool { + self.shutting_down + .load(std::sync::atomic::Ordering::Acquire) + } + + /// Flip the pool into shutting-down state and return a snapshot of every + /// live session's addressing info. Takes the state write lock so the + /// snapshot is atomic with respect to in-flight `get_or_create` calls: + /// any admission that committed before us is included; any that comes + /// after us sees the flag inside the same lock and rejects. + pub async fn begin_shutdown(&self) -> Vec<(String, ChannelRef, Arc<dyn ChatAdapter>)> { + let state = self.state.write().await; + self.shutting_down + .store(true, std::sync::atomic::Ordering::Release); + state + .addresses + .iter() + .map(|(k, (c, a))| (k.clone(), c.clone(), a.clone())) + .collect() + } + + pub async fn get_or_create( + &self, + thread_id: &str, + channel: &ChannelRef, + adapter: &Arc<dyn ChatAdapter>, + ) -> Result<()> { + // Fast-fail: avoid spawning a fresh ACP process if shutdown is already + // in progress. The authoritative check happens again under the state + // write lock below so we also catch shutdowns that start mid-spawn. + if self.is_shutting_down() { + bail!("pool is shutting down"); + } + + let create_gate = { + let mut state = self.state.write().await; + get_or_insert_gate(&mut state.creating, thread_id) @@ -84,6 +132,9 @@ impl SessionPool { let (existing, saved_session_id) = { let state = self.state.read().await; + if self.is_shutting_down() { + bail!("pool is shutting down"); + } ( state.active.get(thread_id).cloned(), state.suspended.get(thread_id).cloned(), @@ -95,6 +146,16 @@ impl SessionPool { if let Some(conn) = existing.clone() { let conn = conn.lock().await; if conn.alive() { + // Re-check shutdown state after waiting on the per-connection + // mutex. 
Taking `state.read()` synchronizes us with + // `begin_shutdown`'s write-lock flag flip, so the flag value + // we see here reflects every `begin_shutdown` that has + // committed. This closes the race where shutdown starts + // while we were waiting on `conn.lock()`. + let _sync = self.state.read().await; + if self.is_shutting_down() { + bail!("pool is shutting down"); + } return Ok(()); } if saved_session_id.is_none() { @@ -174,6 +235,14 @@ impl SessionPool { let mut state = self.state.write().await; + // Admission check inside the state write lock. This is atomic with + // `begin_shutdown`'s flag-flip + snapshot: a shutdown that started + // during our ACP spawn is caught here, and our work is thrown away + // rather than being added to a pool that is about to be torn down. + if self.is_shutting_down() { + bail!("pool is shutting down"); + } + // Another task may have created a healthy connection while we were // initializing this one. if let Some(existing) = state.active.get(thread_id).cloned() { @@ -186,11 +255,13 @@ impl SessionPool { warn!(thread_id, "stale connection, rebuilding"); drop(existing); state.active.remove(thread_id); + state.addresses.remove(thread_id); } if state.active.len() >= self.max_sessions { if let Some((key, expected_conn, _, sid)) = eviction_candidate { if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + state.addresses.remove(&key); info!(evicted = %key, "pool full, suspending oldest idle session"); if let Some(sid) = sid { state.suspended.insert(key, sid); @@ -216,6 +287,9 @@ impl SessionPool { if !cancel_session_id.is_empty() { state.cancel_handles.insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); } + state + .addresses + .insert(thread_id.to_string(), (channel.clone(), adapter.clone())); Ok(()) } @@ -324,6 +398,7 @@ impl SessionPool { let mut state = self.state.write().await; for (key, expected_conn, sid) in stale { if remove_if_same_handle(&mut state.active, &key, 
&expected_conn).is_some() { + state.addresses.remove(&key); info!(thread_id = %key, "cleaning up idle session"); if let Some(sid) = sid { state.suspended.insert(key, sid); @@ -336,6 +411,7 @@ impl SessionPool { let mut state = self.state.write().await; let count = state.active.len(); state.active.clear(); // Drop impl kills process groups + state.addresses.clear(); info!(count, "pool shutdown complete"); } } diff --git a/src/adapter.rs b/src/adapter.rs index 189e98c..864a5c6 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -1,8 +1,8 @@ use anyhow::Result; use async_trait::async_trait; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::sync::Arc; -use tracing::error; +use tracing::{error, info, warn}; use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool}; use crate::config::ReactionsConfig; @@ -32,7 +32,7 @@ pub struct MessageRef { } /// Sender identity injected into prompts for downstream agent context. -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SenderContext { pub schema: String, pub sender_id: String, @@ -154,11 +154,33 @@ impl AdapterRouter { .unwrap_or(&thread_channel.channel_id) ); - if let Err(e) = self.pool.get_or_create(&thread_key).await { - let msg = format_user_error(&e.to_string()); - let _ = adapter - .send_message(thread_channel, &format!("⚠️ {msg}")) - .await; + // Session admission. The pool itself is authoritative: it rejects + // with an error if `begin_shutdown` has already fired, and on success + // stores the `ChannelRef` + adapter so `broadcast_shutdown` can reach + // this thread without the router keeping a parallel cache. + if let Err(e) = self.pool.get_or_create(&thread_key, thread_channel, adapter).await { + if self.pool.is_shutting_down() { + // Don't send the shutdown rejection back to bot-authored events. 
+ // Slack (and potentially any other platform that doesn't drop + // the bot's own posts) would deliver our broadcast message as a + // new bot event, route it here during the shutdown window, and + // we'd reply with another bot-authored rejection — looping until + // the bot-turn cap trips. Human senders still get the notice. + let sender_is_bot = serde_json::from_str::<SenderContext>(sender_json) + .map(|sender| sender.is_bot) + .unwrap_or(false); + if !sender_is_bot { + let _ = adapter + .send_message( + thread_channel, + "⚠️ Bot is shutting down and cannot accept new messages right now.", + ) + .await; + } + return Ok(()); + } + let msg = format!("⚠️ {}", format_user_error(&e.to_string())); + let _ = adapter.send_message(thread_channel, &msg).await; error!("pool error: {e}"); return Err(e); } @@ -209,6 +231,58 @@ impl AdapterRouter { result } + /// Broadcast a short notification to every active thread, across all + /// configured adapters, before the broker shuts down. Sends happen in + /// parallel and are capped by `timeout`; the call returns early if the + /// deadline is hit so shutdown itself is never blocked by a slow platform. + /// + /// Delivery is best-effort: evicted sessions whose `ChannelRef` is still + /// in the cache still receive the notification, which is the behavior we + /// want (the user saw the thread was in flight; they deserve to know the + /// broker is going away). + pub async fn broadcast_shutdown(&self, message: &str, timeout: std::time::Duration) { + // The pool owns both the flag flip and the live-session snapshot and + // performs both atomically under its state write lock. Any message + // admitted before us is in the snapshot; any that comes after sees + // the flag inside the same lock and returns an admission error that + // `handle_message` surfaces inline. 
+ let snapshot = self.pool.begin_shutdown().await; + + if snapshot.is_empty() { + return; + } + + info!(count = snapshot.len(), "broadcasting shutdown notification"); + + let mut set = tokio::task::JoinSet::new(); + for (thread_key, channel, adapter) in snapshot { + let message = message.to_string(); + set.spawn(async move { + if let Err(e) = adapter.send_message(&channel, &message).await { + warn!(thread_key, error = %e, "failed to post shutdown notification"); + } + }); + } + + let deadline = tokio::time::sleep(timeout); + tokio::pin!(deadline); + loop { + tokio::select! { + biased; + _ = &mut deadline => { + warn!(timeout_ms = timeout.as_millis() as u64, "shutdown broadcast timed out; remaining sends cancelled"); + set.shutdown().await; + return; + } + next = set.join_next() => { + if next.is_none() { + return; + } + } + } + } + } + async fn stream_prompt( &self, adapter: &Arc<dyn ChatAdapter>, diff --git a/src/discord.rs b/src/discord.rs index ac94759..105651e 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -517,7 +517,14 @@ impl EventHandler for Handler { tokio::spawn(async move { let sender_json = serde_json::to_string(&sender).unwrap(); if let Err(e) = router - .handle_message(&adapter, &thread_channel, &sender_json, &prompt, extra_blocks, &trigger_msg) + .handle_message( + &adapter, + &thread_channel, + &sender_json, + &prompt, + extra_blocks, + &trigger_msg, + ) .await { error!("handle_message error: {e}"); diff --git a/src/main.rs b/src/main.rs index 53927f4..5633785 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,15 @@ use std::path::PathBuf; use std::sync::Arc; use tracing::{error, info, warn}; +/// Neutral shutdown notification broadcast to every active thread. Wording +/// deliberately avoids "restarting" because `helm uninstall` / final-stop +/// can't be distinguished from a rolling restart at signal time. +const SHUTDOWN_MSG: &str = "⚠️ Bot is shutting down. Context will reset on return."; + +/// Broadcast deadline. 
Shutdown itself must never block on a slow platform, +/// so incomplete sends are dropped once this elapses. +const SHUTDOWN_BROADCAST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10); + #[derive(Parser)] #[command(name = "openab")] #[command(about = "Multi-platform ACP agent broker (Discord, Slack)", long_about = None)] @@ -176,7 +185,7 @@ async fn main() -> anyhow::Result<()> { ); let handler = discord::Handler { - router, + router: router.clone(), allow_all_channels, allow_all_users, allowed_channels, @@ -201,21 +210,31 @@ async fn main() -> anyhow::Result<()> { .event_handler(handler) .await?; - // Graceful Discord shutdown on ctrl_c + // Graceful shutdown on SIGINT or SIGTERM: wait for the signal, + // broadcast to every active thread (Discord + Slack), then stop + // Discord shards. `client.start()` is the foreground blocker here, + // so this handler runs as a spawned task. let shard_manager = client.shard_manager.clone(); + let shutdown_router = router.clone(); tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - info!("shutdown signal received"); + wait_for_shutdown_signal().await; + shutdown_router + .broadcast_shutdown(SHUTDOWN_MSG, SHUTDOWN_BROADCAST_TIMEOUT) + .await; shard_manager.shutdown_all().await; }); info!("discord bot running"); client.start().await?; } else { - // No Discord — just wait for ctrl_c - info!("running without discord, press ctrl+c to stop"); - tokio::signal::ctrl_c().await.ok(); - info!("shutdown signal received"); + // No Discord — this task itself blocks on the shutdown signal, + // then broadcasts before falling through to cleanup. Slack-only + // deployments need SIGTERM + broadcast just like Discord. 
info!("running without discord, waiting for shutdown signal"); wait_for_shutdown_signal().await; router + .broadcast_shutdown(SHUTDOWN_MSG, SHUTDOWN_BROADCAST_TIMEOUT) + .await; } // Cleanup @@ -233,6 +252,28 @@ } } +/// Wait for SIGINT (ctrl_c) or, on Unix, SIGTERM (systemctl stop, docker stop, +/// kill). Without SIGTERM handling the broker would be killed outright by +/// service managers and skip the shutdown broadcast, so both signals route +/// here on Unix. On non-Unix targets `tokio::signal::unix` is unavailable, so +/// we fall back to ctrl_c alone. +#[cfg(unix)] +async fn wait_for_shutdown_signal() { + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("install SIGTERM handler"); + tokio::select! { + _ = tokio::signal::ctrl_c() => {} + _ = sigterm.recv() => {} + } + info!("shutdown signal received"); +} + +#[cfg(not(unix))] +async fn wait_for_shutdown_signal() { + tokio::signal::ctrl_c().await.ok(); + info!("shutdown signal received"); +} + fn parse_id_set(raw: &[String], label: &str) -> anyhow::Result<HashSet<String>> { let set: HashSet<String> = raw .iter() diff --git a/src/slack.rs b/src/slack.rs index 92155e2..0675d75 100644 --- a/src/slack.rs +++ b/src/slack.rs @@ -953,7 +953,14 @@ async fn handle_message( let adapter_dyn: Arc<dyn ChatAdapter> = adapter.clone(); if let Err(e) = router - .handle_message(&adapter_dyn, &thread_channel, &sender_json, &prompt, extra_blocks, &trigger_msg) + .handle_message( + &adapter_dyn, + &thread_channel, + &sender_json, + &prompt, + extra_blocks, + &trigger_msg, + ) .await { error!("Slack handle_message error: {e}");