Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions src/acp/pool.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::acp::connection::AcpConnection;
use crate::adapter::{ChannelRef, ChatAdapter};
use crate::config::AgentConfig;
use anyhow::{anyhow, Result};
use anyhow::{anyhow, bail, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
Expand All @@ -12,6 +13,11 @@ use tracing::{info, warn};
struct PoolState {
/// Active connections: thread_key → AcpConnection handle.
active: HashMap<String, Arc<Mutex<AcpConnection>>>,
/// Addressing info for each active thread. Populated alongside `active`
/// and pruned together. Used by `begin_shutdown` so the broker can post
/// a notification to every live session without the adapter layer
/// maintaining a parallel cache. Invariant: `addresses.keys() == active.keys()`.
addresses: HashMap<String, (ChannelRef, Arc<dyn ChatAdapter>)>,
/// Suspended sessions: thread_key → ACP sessionId.
/// Saved on eviction so sessions can be resumed via `session/load`.
suspended: HashMap<String, String>,
Expand All @@ -24,6 +30,10 @@ pub struct SessionPool {
state: RwLock<PoolState>,
config: AgentConfig,
max_sessions: usize,
/// Flipped by `begin_shutdown` to reject new admissions. Checked inside
/// `get_or_create` under the state write lock so admission and snapshot
/// are atomic.
shutting_down: std::sync::atomic::AtomicBool,
}

type EvictionCandidate = (
Expand Down Expand Up @@ -62,15 +72,53 @@ impl SessionPool {
Self {
state: RwLock::new(PoolState {
active: HashMap::new(),
addresses: HashMap::new(),
suspended: HashMap::new(),
creating: HashMap::new(),
}),
config,
max_sessions,
shutting_down: std::sync::atomic::AtomicBool::new(false),
}
}

pub async fn get_or_create(&self, thread_id: &str) -> Result<()> {
/// Reports whether `begin_shutdown` has fired. The router consults this to
/// replace a generic pool error with a shutdown-specific notice when an
/// admission attempt is rejected by `get_or_create`.
pub fn is_shutting_down(&self) -> bool {
    use std::sync::atomic::Ordering;
    self.shutting_down.load(Ordering::Acquire)
}

/// Mark the pool as shutting down and hand back the addressing info for
/// every live session. The state write lock is held across both steps, so
/// the snapshot is atomic with respect to in-flight `get_or_create` calls:
/// an admission that committed before us is included in the snapshot, while
/// one that arrives after observes the flag under the same lock and is
/// rejected.
pub async fn begin_shutdown(&self) -> Vec<(String, ChannelRef, Arc<dyn ChatAdapter>)> {
    let guard = self.state.write().await;
    self.shutting_down
        .store(true, std::sync::atomic::Ordering::Release);
    let mut snapshot = Vec::with_capacity(guard.addresses.len());
    for (key, (channel, adapter)) in guard.addresses.iter() {
        snapshot.push((key.clone(), channel.clone(), Arc::clone(adapter)));
    }
    snapshot
}

pub async fn get_or_create(
&self,
thread_id: &str,
channel: &ChannelRef,
adapter: &Arc<dyn ChatAdapter>,
) -> Result<()> {
// Fast-fail: avoid spawning a fresh ACP process if shutdown is already
// in progress. The authoritative check happens again under the state
// write lock below so we also catch shutdowns that start mid-spawn.
if self.is_shutting_down() {
bail!("pool is shutting down");
}

let create_gate = {
let mut state = self.state.write().await;
get_or_insert_gate(&mut state.creating, thread_id)
Expand All @@ -79,6 +127,9 @@ impl SessionPool {

let (existing, saved_session_id) = {
let state = self.state.read().await;
if self.is_shutting_down() {
bail!("pool is shutting down");
}
(
state.active.get(thread_id).cloned(),
state.suspended.get(thread_id).cloned(),
Expand All @@ -90,6 +141,16 @@ impl SessionPool {
if let Some(conn) = existing.clone() {
let conn = conn.lock().await;
if conn.alive() {
// Re-check shutdown state after waiting on the per-connection
// mutex. Taking `state.read()` synchronizes us with
// `begin_shutdown`'s write-lock flag flip, so the flag value
// we see here reflects every `begin_shutdown` that has
// committed. This closes the race where shutdown starts
// while we were waiting on `conn.lock()`.
let _sync = self.state.read().await;
if self.is_shutting_down() {
bail!("pool is shutting down");
}
return Ok(());
}
if saved_session_id.is_none() {
Expand Down Expand Up @@ -167,6 +228,14 @@ impl SessionPool {

let mut state = self.state.write().await;

// Admission check inside the state write lock. This is atomic with
// `begin_shutdown`'s flag-flip + snapshot: a shutdown that started
// during our ACP spawn is caught here, and our work is thrown away
// rather than being added to a pool that is about to be torn down.
if self.is_shutting_down() {
bail!("pool is shutting down");
}

// Another task may have created a healthy connection while we were
// initializing this one.
if let Some(existing) = state.active.get(thread_id).cloned() {
Expand All @@ -179,11 +248,13 @@ impl SessionPool {
warn!(thread_id, "stale connection, rebuilding");
drop(existing);
state.active.remove(thread_id);
state.addresses.remove(thread_id);
}

if state.active.len() >= self.max_sessions {
if let Some((key, expected_conn, _, sid)) = eviction_candidate {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
state.addresses.remove(&key);
info!(evicted = %key, "pool full, suspending oldest idle session");
if let Some(sid) = sid {
state.suspended.insert(key, sid);
Expand All @@ -206,6 +277,9 @@ impl SessionPool {

state.suspended.remove(thread_id);
state.active.insert(thread_id.to_string(), new_conn);
state
.addresses
.insert(thread_id.to_string(), (channel.clone(), adapter.clone()));
Ok(())
}

Expand Down Expand Up @@ -261,6 +335,7 @@ impl SessionPool {
let mut state = self.state.write().await;
for (key, expected_conn, sid) in stale {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
state.addresses.remove(&key);
info!(thread_id = %key, "cleaning up idle session");
if let Some(sid) = sid {
state.suspended.insert(key, sid);
Expand All @@ -273,6 +348,7 @@ impl SessionPool {
let mut state = self.state.write().await;
let count = state.active.len();
state.active.clear(); // Drop impl kills process groups
state.addresses.clear();
info!(count, "pool shutdown complete");
}
}
Expand Down
90 changes: 82 additions & 8 deletions src/adapter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::Result;
use async_trait::async_trait;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::error;
use tracing::{error, info, warn};

use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool};
use crate::config::ReactionsConfig;
Expand Down Expand Up @@ -32,7 +32,7 @@ pub struct MessageRef {
}

/// Sender identity injected into prompts for downstream agent context.
#[derive(Clone, Debug, Serialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SenderContext {
pub schema: String,
pub sender_id: String,
Expand Down Expand Up @@ -149,11 +149,33 @@ impl AdapterRouter {
.unwrap_or(&thread_channel.channel_id)
);

if let Err(e) = self.pool.get_or_create(&thread_key).await {
let msg = format_user_error(&e.to_string());
let _ = adapter
.send_message(thread_channel, &format!("⚠️ {msg}"))
.await;
// Session admission. The pool itself is authoritative: it rejects
// with an error if `begin_shutdown` has already fired, and on success
// stores the `ChannelRef` + adapter so `broadcast_shutdown` can reach
// this thread without the router keeping a parallel cache.
if let Err(e) = self.pool.get_or_create(&thread_key, thread_channel, adapter).await {
if self.pool.is_shutting_down() {
// Don't send the shutdown rejection back to bot-authored events.
// Slack (and potentially any other platform that doesn't drop
// the bot's own posts) would deliver our broadcast message as a
// new bot event, route it here during the shutdown window, and
// we'd reply with another bot-authored rejection — looping until
// the bot-turn cap trips. Human senders still get the notice.
let sender_is_bot = serde_json::from_str::<SenderContext>(sender_json)
.map(|sender| sender.is_bot)
.unwrap_or(false);
if !sender_is_bot {
let _ = adapter
.send_message(
thread_channel,
"⚠️ Bot is shutting down and cannot accept new messages right now.",
)
.await;
}
return Ok(());
}
let msg = format!("⚠️ {}", format_user_error(&e.to_string()));
let _ = adapter.send_message(thread_channel, &msg).await;
error!("pool error: {e}");
return Err(e);
}
Expand Down Expand Up @@ -204,6 +226,58 @@ impl AdapterRouter {
result
}

/// Post a short notice to every active thread on every configured adapter
/// ahead of broker shutdown. Sends run concurrently and the whole batch is
/// bounded by `timeout`; once the deadline passes the remaining sends are
/// abandoned so a slow platform can never stall shutdown itself.
///
/// Delivery is best-effort: a `ChannelRef` that lingers in the cache after
/// its session was evicted still receives the notice, and that is the
/// behavior we want (the user saw the thread was in flight; they deserve
/// to know the broker is going away).
pub async fn broadcast_shutdown(&self, message: &str, timeout: std::time::Duration) {
    // The pool flips the flag and snapshots live sessions atomically under
    // its state write lock. Any message admitted before the flip appears in
    // the snapshot; anything admitted afterwards sees the flag under the
    // same lock and gets the admission error that `handle_message` surfaces
    // inline.
    let live = self.pool.begin_shutdown().await;
    if live.is_empty() {
        return;
    }

    info!(count = live.len(), "broadcasting shutdown notification");

    // Fan the sends out as independent tasks so one slow platform does not
    // serialize the rest.
    let mut sends = tokio::task::JoinSet::new();
    for (thread_key, channel, adapter) in live {
        let text = message.to_string();
        sends.spawn(async move {
            if let Err(e) = adapter.send_message(&channel, &text).await {
                warn!(thread_key, error = %e, "failed to post shutdown notification");
            }
        });
    }

    // Drain completions until either the set is empty or the deadline hits.
    // `biased` keeps the deadline arm checked first on every iteration.
    let timer = tokio::time::sleep(timeout);
    tokio::pin!(timer);
    loop {
        tokio::select! {
            biased;
            _ = &mut timer => {
                warn!(timeout_ms = timeout.as_millis() as u64, "shutdown broadcast timed out; remaining sends cancelled");
                sends.shutdown().await;
                return;
            }
            joined = sends.join_next() => {
                if joined.is_none() {
                    return;
                }
            }
        }
    }
}

async fn stream_prompt(
&self,
adapter: &Arc<dyn ChatAdapter>,
Expand Down
9 changes: 8 additions & 1 deletion src/discord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,14 @@ impl EventHandler for Handler {
tokio::spawn(async move {
let sender_json = serde_json::to_string(&sender).unwrap();
if let Err(e) = router
.handle_message(&adapter, &thread_channel, &sender_json, &prompt, extra_blocks, &trigger_msg)
.handle_message(
&adapter,
&thread_channel,
&sender_json,
&prompt,
extra_blocks,
&trigger_msg,
)
.await
{
error!("handle_message error: {e}");
Expand Down
57 changes: 49 additions & 8 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ use std::path::PathBuf;
use std::sync::Arc;
use tracing::{error, info};

/// Neutral shutdown notification broadcast to every active thread. Wording
/// deliberately avoids "restarting" because `helm uninstall` / final-stop
/// can't be distinguished from a rolling restart at signal time.
/// Passed to `router.broadcast_shutdown` from both the Discord and the
/// Slack-only shutdown paths below.
const SHUTDOWN_MSG: &str = "⚠️ Bot is shutting down. Context will reset on return.";

/// Broadcast deadline. Shutdown itself must never block on a slow platform,
/// so incomplete sends are dropped once this elapses.
const SHUTDOWN_BROADCAST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

#[derive(Parser)]
#[command(name = "openab")]
#[command(about = "Multi-platform ACP agent broker (Discord, Slack)", long_about = None)]
Expand Down Expand Up @@ -160,7 +169,7 @@ async fn main() -> anyhow::Result<()> {
);

let handler = discord::Handler {
router,
router: router.clone(),
allowed_channels,
allowed_users,
stt_config: cfg.stt.clone(),
Expand All @@ -180,21 +189,31 @@ async fn main() -> anyhow::Result<()> {
.event_handler(handler)
.await?;

// Graceful Discord shutdown on ctrl_c
// Graceful shutdown on SIGINT or SIGTERM: wait for the signal,
// broadcast to every active thread (Discord + Slack), then stop
// Discord shards. `client.start()` is the foreground blocker here,
// so this handler runs as a spawned task.
let shard_manager = client.shard_manager.clone();
let shutdown_router = router.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
info!("shutdown signal received");
wait_for_shutdown_signal().await;
shutdown_router
.broadcast_shutdown(SHUTDOWN_MSG, SHUTDOWN_BROADCAST_TIMEOUT)
.await;
shard_manager.shutdown_all().await;
});

info!("discord bot running");
client.start().await?;
} else {
// No Discord — just wait for ctrl_c
info!("running without discord, press ctrl+c to stop");
tokio::signal::ctrl_c().await.ok();
info!("shutdown signal received");
// No Discord — this task itself blocks on the shutdown signal,
// then broadcasts before falling through to cleanup. Slack-only
// deployments need SIGTERM + broadcast just like Discord.
info!("running without discord, waiting for shutdown signal");
wait_for_shutdown_signal().await;
router
.broadcast_shutdown(SHUTDOWN_MSG, SHUTDOWN_BROADCAST_TIMEOUT)
.await;
}

// Cleanup
Expand All @@ -212,6 +231,28 @@ async fn main() -> anyhow::Result<()> {
}
}

/// Block until SIGINT (ctrl_c) or, on Unix, SIGTERM (systemctl stop,
/// docker stop, kill) arrives. Service managers deliver SIGTERM, so
/// handling only ctrl_c would let them kill the broker outright and skip
/// the shutdown broadcast; both signals therefore funnel through here on
/// Unix. Non-Unix targets lack `tokio::signal::unix` and fall back to
/// ctrl_c alone.
#[cfg(unix)]
async fn wait_for_shutdown_signal() {
    use tokio::signal::unix::{signal, SignalKind};
    let mut term = signal(SignalKind::terminate()).expect("install SIGTERM handler");
    tokio::select! {
        _ = tokio::signal::ctrl_c() => {}
        _ = term.recv() => {}
    }
    info!("shutdown signal received");
}

// Non-Unix fallback: `tokio::signal::unix` is unavailable, so ctrl_c is the
// only shutdown trigger we can wait on.
#[cfg(not(unix))]
async fn wait_for_shutdown_signal() {
    let _ = tokio::signal::ctrl_c().await;
    info!("shutdown signal received");
}

fn parse_id_set(raw: &[String], label: &str) -> anyhow::Result<HashSet<u64>> {
let set: HashSet<u64> = raw
.iter()
Expand Down
Loading
Loading