diff --git a/.gitignore b/.gitignore index 16dd32e..c58f29e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ config.toml *.swp .DS_Store +.env +.kiro/ diff --git a/charts/openab/templates/configmap.yaml b/charts/openab/templates/configmap.yaml index e0c4a61..ed729a0 100644 --- a/charts/openab/templates/configmap.yaml +++ b/charts/openab/templates/configmap.yaml @@ -42,6 +42,13 @@ data: {{- if $cfg.discord.trustedBotIds }} trusted_bot_ids = {{ $cfg.discord.trustedBotIds | toJson }} {{- end }} + {{- /* allowUserMessages: controls whether the bot requires @mention in threads (Discord) */ -}} + {{- if $cfg.discord.allowUserMessages }} + {{- if not (has $cfg.discord.allowUserMessages (list "involved" "mentions")) }} + {{- fail (printf "agents.%s.discord.allowUserMessages must be one of: involved, mentions — got: %s" $name $cfg.discord.allowUserMessages) }} + {{- end }} + allow_user_messages = {{ $cfg.discord.allowUserMessages | toJson }} {{- /* involved (default): respond in bot's threads without @mention | mentions: always require @mention */ -}} + {{- end }} {{- end }} {{- if and ($cfg.slack).enabled }} @@ -73,7 +80,7 @@ data: {{- if not (has ($cfg.slack).allowUserMessages (list "involved" "mentions")) }} {{- fail (printf "agents.%s.slack.allowUserMessages must be one of: involved, mentions — got: %s" $name ($cfg.slack).allowUserMessages) }} {{- end }} - allow_user_messages = {{ ($cfg.slack).allowUserMessages | toJson }} + allow_user_messages = {{ ($cfg.slack).allowUserMessages | toJson }} {{- /* involved (default): respond in bot's threads without @mention | mentions: always require @mention */ -}} {{- end }} {{- end }} diff --git a/src/adapter.rs b/src/adapter.rs index 2c2d096..5c39369 100644 --- a/src/adapter.rs +++ b/src/adapter.rs @@ -2,7 +2,6 @@ use anyhow::Result; use async_trait::async_trait; use serde::Serialize; use std::sync::Arc; -use tokio::sync::watch; use tracing::error; use crate::acp::{classify_notification, AcpEvent, ContentBlock, SessionPool}; @@ -41,6 +40,10 @@ pub struct SenderContext { pub display_name: String, pub channel: String, pub channel_id: String, + /// Thread identifier, if the message is inside a thread. + /// Slack: thread_ts. Discord: None (threads are separate channels). + #[serde(skip_serializing_if = "Option::is_none")] + pub thread_id: Option, pub is_bot: bool, } @@ -57,9 +60,6 @@ pub trait ChatAdapter: Send + Sync + 'static { /// Send a new message, returns a reference to the sent message. async fn send_message(&self, channel: &ChannelRef, content: &str) -> Result; - /// Edit an existing message in-place. - async fn edit_message(&self, msg: &MessageRef, content: &str) -> Result<()>; - /// Create a thread from a trigger message, returns the thread channel ref. async fn create_thread( &self, @@ -73,6 +73,17 @@ pub trait ChatAdapter: Send + Sync + 'static { /// Remove a reaction/emoji from a message. async fn remove_reaction(&self, msg: &MessageRef, emoji: &str) -> Result<()>; + + /// Edit an existing message in-place (for streaming updates). + /// Default: unsupported (send-once only). + async fn edit_message(&self, _msg: &MessageRef, _content: &str) -> Result<()> { + Err(anyhow::anyhow!("edit_message not supported")) + } + + /// Whether this adapter should use streaming edit (true) or send-once (false). + fn use_streaming(&self) -> bool { + false + } } // --- AdapterRouter --- @@ -130,8 +141,6 @@ impl AdapterRouter { } } - let thinking_msg = adapter.send_message(thread_channel, "...").await?; - let thread_key = format!( "{}:{}", adapter.platform(), @@ -144,7 +153,7 @@ impl AdapterRouter { if let Err(e) = self.pool.get_or_create(&thread_key).await { let msg = format_user_error(&e.to_string()); let _ = adapter - .edit_message(&thinking_msg, &format!("⚠️ {msg}")) + .send_message(thread_channel, &format!("⚠️ {msg}")) .await; error!("pool error: {e}"); return Err(e); @@ -165,7 +174,6 @@ impl AdapterRouter { &thread_key, content_blocks, thread_channel, - &thinking_msg, reactions.clone(), ) .await; @@ -190,7 +198,7 @@ impl AdapterRouter { if let Err(ref e) = result { let _ = adapter - .edit_message(&thinking_msg, &format!("⚠️ {e}")) + .send_message(thread_channel, &format!("⚠️ {e}")) .await; } @@ -203,13 +211,12 @@ impl AdapterRouter { thread_key: &str, content_blocks: Vec, thread_channel: &ChannelRef, - thinking_msg: &MessageRef, reactions: Arc, ) -> Result<()> { let adapter = adapter.clone(); let thread_channel = thread_channel.clone(); - let msg_ref = thinking_msg.clone(); let message_limit = adapter.message_limit(); + let streaming = adapter.use_streaming(); self.pool .with_connection(thread_key, |conn| { @@ -221,13 +228,6 @@ impl AdapterRouter { let (mut rx, _) = conn.session_prompt(content_blocks).await?; reactions.set_thinking().await; - let initial = if reset { - "⚠️ _Session expired, starting fresh..._\n\n...".to_string() - } else { - "...".to_string() - }; - let (buf_tx, buf_rx) = watch::channel(initial); - let mut text_buf = String::new(); let mut tool_lines: Vec = Vec::new(); @@ -235,43 +235,44 @@ impl AdapterRouter { text_buf.push_str("⚠️ _Session expired, starting fresh..._\n\n"); } - // Spawn edit-streaming task — only edits the single message, never sends new ones. - // Long content is truncated during streaming; final multi-message split happens after. - let streaming_limit = message_limit.saturating_sub(100); - let edit_handle = { - let adapter = adapter.clone(); - let msg_ref = msg_ref.clone(); - let mut buf_rx = buf_rx.clone(); + // Streaming edit: send placeholder, spawn edit loop + let (buf_tx, placeholder_msg) = if streaming { + let initial = if reset { + "⚠️ _Session expired, starting fresh..._\n\n…".to_string() + } else { + "…".to_string() + }; + let msg = adapter.send_message(&thread_channel, &initial).await?; + let (tx, rx) = tokio::sync::watch::channel(initial); + let edit_adapter = adapter.clone(); + let edit_msg = msg.clone(); + let limit = message_limit; + let mut buf_rx = rx; tokio::spawn(async move { - let mut last_content = String::new(); + let mut last = String::new(); loop { tokio::time::sleep(std::time::Duration::from_millis(1500)).await; if buf_rx.has_changed().unwrap_or(false) { let content = buf_rx.borrow_and_update().clone(); - if content != last_content { - let display = if content.chars().count() > streaming_limit { - // Tail-priority: keep the last N chars so user - // sees the most recent agent output - let total = content.chars().count(); - let skip = total - streaming_limit; - let truncated: String = content.chars().skip(skip).collect(); - format!("…(truncated)\n{truncated}") + if content != last { + let display = if content.chars().count() > limit - 100 { + format!("…{}", format::truncate_chars_tail(&content, limit - 100)) } else { content.clone() }; - let _ = adapter.edit_message(&msg_ref, &display).await; - last_content = content; + let _ = edit_adapter.edit_message(&edit_msg, &display).await; + last = content; } } - if buf_rx.has_changed().is_err() { - break; - } + if buf_rx.has_changed().is_err() { break; } } - }) + }); + (Some(tx), Some(msg)) + } else { + (None, None) }; // Process ACP notifications - let mut got_first_text = false; let mut response_error: Option = None; while let Some(notification) = rx.recv().await { if notification.id.is_some() { @@ -284,12 +285,10 @@ impl AdapterRouter { if let Some(event) = classify_notification(¬ification) { match event { AcpEvent::Text(t) => { - if !got_first_text { - got_first_text = true; - } text_buf.push_str(&t); - let _ = - buf_tx.send(compose_display(&tool_lines, &text_buf, true)); + if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display(&tool_lines, &text_buf, true)); + } } AcpEvent::Thinking => { reactions.set_thinking().await; @@ -307,8 +306,9 @@ impl AdapterRouter { state: ToolState::Running, }); } - let _ = - buf_tx.send(compose_display(&tool_lines, &text_buf, true)); + if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display(&tool_lines, &text_buf, true)); + } } AcpEvent::ToolDone { id, title, status } => { reactions.set_thinking().await; @@ -329,8 +329,9 @@ impl AdapterRouter { state: new_state, }); } - let _ = - buf_tx.send(compose_display(&tool_lines, &text_buf, true)); + if let Some(tx) = &buf_tx { + let _ = tx.send(compose_display(&tool_lines, &text_buf, true)); + } } _ => {} } @@ -338,10 +339,10 @@ impl AdapterRouter { } conn.prompt_done().await; + // Stop the edit loop drop(buf_tx); - let _ = edit_handle.await; - // Final edit with complete content + // Build final content let final_content = compose_display(&tool_lines, &text_buf, false); let final_content = if final_content.is_empty() { if let Some(err) = response_error { @@ -356,14 +357,18 @@ impl AdapterRouter { }; let chunks = format::split_message(&final_content, message_limit); - let mut current_msg = msg_ref; - for (i, chunk) in chunks.iter().enumerate() { - if i == 0 { - let _ = adapter.edit_message(¤t_msg, chunk).await; - } else if let Ok(new_msg) = - adapter.send_message(&thread_channel, chunk).await - { - current_msg = new_msg; + if let Some(msg) = placeholder_msg { + // Streaming: edit first chunk into placeholder, send rest as new messages + if let Some(first) = chunks.first() { + let _ = adapter.edit_message(&msg, first).await; + } + for chunk in chunks.iter().skip(1) { + let _ = adapter.send_message(&thread_channel, chunk).await; + } + } else { + // Send-once: all chunks as new messages + for chunk in &chunks { + let _ = adapter.send_message(&thread_channel, chunk).await; } } diff --git a/src/discord.rs b/src/discord.rs index dead2d2..11c0e2e 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -5,7 +5,7 @@ use crate::format; use crate::media; use async_trait::async_trait; use std::sync::LazyLock; -use serenity::builder::{CreateThread, EditMessage}; +use serenity::builder::CreateThread; use serenity::http::Http; use serenity::model::channel::{AutoArchiveDuration, Message, ReactionType}; use serenity::model::gateway::Ready; @@ -54,19 +54,6 @@ impl ChatAdapter for DiscordAdapter { }) } - async fn edit_message(&self, msg: &MessageRef, content: &str) -> anyhow::Result<()> { - let ch_id: u64 = msg.channel.channel_id.parse()?; - let msg_id: u64 = msg.message_id.parse()?; - ChannelId::new(ch_id) - .edit_message( - &self.http, - MessageId::new(msg_id), - EditMessage::new().content(content), - ) - .await?; - Ok(()) - } - async fn create_thread( &self, channel: &ChannelRef, @@ -216,11 +203,7 @@ impl EventHandler for Handler { self.allowed_channels.is_empty() || self.allowed_channels.contains(&channel_id); let is_mentioned = msg.mentions_user_id(bot_id) - || msg.content.contains(&format!("<@{}>", bot_id)) - || msg - .mention_roles - .iter() - .any(|r| msg.content.contains(&format!("<@&{}>", r))); + || msg.content.contains(&format!("<@{}>", bot_id)); // Bot message gating (from upstream #321) if msg.author.bot { @@ -361,6 +344,7 @@ impl EventHandler for Handler { display_name: display_name.to_string(), channel: "discord".into(), channel_id: msg.channel_id.to_string(), + thread_id: None, is_bot: msg.author.bot, }; diff --git a/src/format.rs b/src/format.rs index 56f0fad..fbf76c6 100644 --- a/src/format.rs +++ b/src/format.rs @@ -60,3 +60,13 @@ pub fn shorten_thread_name(prompt: &str) -> String { } } + +/// Truncate a string to at most `limit` Unicode characters, keeping the tail +/// (most recent output) for better streaming UX. +pub fn truncate_chars_tail(s: &str, limit: usize) -> String { + let total = s.chars().count(); + if total <= limit { + return s.to_string(); + } + s.chars().skip(total - limit).collect() +} diff --git a/src/slack.rs b/src/slack.rs index fabf31b..15d7e1a 100644 --- a/src/slack.rs +++ b/src/slack.rs @@ -63,10 +63,12 @@ pub struct SlackAdapter { participated_threads: tokio::sync::Mutex>, /// TTL for participation cache entries (matches session_ttl_hours from config). session_ttl: std::time::Duration, + /// Controls streaming behavior: Off → streaming edit, Mentions/All → send-once. + allow_bot_messages: AllowBots, } impl SlackAdapter { - pub fn new(bot_token: String, session_ttl: std::time::Duration) -> Self { + pub fn new(bot_token: String, session_ttl: std::time::Duration, allow_bot_messages: AllowBots) -> Self { Self { client: reqwest::Client::new(), bot_token, @@ -75,6 +77,7 @@ impl SlackAdapter { bot_id_cache: tokio::sync::Mutex::new(HashMap::new()), participated_threads: tokio::sync::Mutex::new(HashMap::new()), session_ttl, + allow_bot_messages, } } @@ -318,19 +321,6 @@ impl ChatAdapter for SlackAdapter { }) } - async fn edit_message(&self, msg: &MessageRef, content: &str) -> Result<()> { - let mrkdwn = markdown_to_mrkdwn(content); - self.api_post( - "chat.update", - serde_json::json!({ - "channel": msg.channel.channel_id, - "ts": msg.message_id, - "text": mrkdwn, - }), - ) - .await?; - Ok(()) - } async fn create_thread( &self, @@ -382,6 +372,68 @@ impl ChatAdapter for SlackAdapter { Err(e) => Err(e), } } + + async fn edit_message(&self, msg: &MessageRef, content: &str) -> Result<()> { + let mrkdwn = markdown_to_mrkdwn(content); + self.api_post( + "chat.update", + serde_json::json!({ + "channel": msg.channel.channel_id, + "ts": msg.message_id, + "text": mrkdwn, + }), + ) + .await?; + Ok(()) + } + + fn use_streaming(&self) -> bool { + self.allow_bot_messages == AllowBots::Off + } +} + +// --- Per-thread async queue (inspired by OpenClaw's KeyedAsyncQueue) --- + +/// Serialize async work per key while allowing unrelated keys to run concurrently. +/// Same-key tasks execute in FIFO order; different keys run in parallel. +/// Idle keys are cleaned up automatically after the last task settles. +struct KeyedAsyncQueue { + tails: tokio::sync::Mutex>>, +} + +impl KeyedAsyncQueue { + fn new() -> Self { + Self { + tails: tokio::sync::Mutex::new(HashMap::new()), + } + } + + /// Acquire a per-key permit. The returned guard must be held for the + /// duration of the async work. Dropping it allows the next queued task + /// for the same key to proceed. + /// + /// Performs lazy cleanup of idle semaphores to prevent unbounded growth + /// in long-running deployments. + async fn acquire(&self, key: &str) -> Option { + let sem = { + let mut tails = self.tails.lock().await; + // Lazy cleanup: evict idle entries (available_permits == 1 means no one is holding or waiting) + if tails.len() > 100 { + tails.retain(|_, sem| Arc::strong_count(sem) > 1 || sem.available_permits() < 1); + } + tails + .entry(key.to_string()) + .or_insert_with(|| Arc::new(tokio::sync::Semaphore::new(1))) + .clone() + }; + match sem.acquire_owned().await { + Ok(permit) => Some(permit), + Err(e) => { + warn!(key, error = %e, "semaphore closed, skipping message"); + None + } + } + } } // --- Socket Mode event loop --- @@ -405,7 +457,8 @@ pub async fn run_slack_adapter( router: Arc, mut shutdown_rx: watch::Receiver, ) -> Result<()> { - let adapter = Arc::new(SlackAdapter::new(bot_token.clone(), session_ttl)); + let adapter = Arc::new(SlackAdapter::new(bot_token.clone(), session_ttl, allow_bot_messages)); + let queue = Arc::new(KeyedAsyncQueue::new()); loop { // Check for shutdown before (re)connecting @@ -455,6 +508,26 @@ pub async fn run_slack_adapter( let event_type = event["type"].as_str().unwrap_or(""); match event_type { "app_mention" => { + // Apply bot gating for app_mention events (same rules as message events) + let is_bot = event["bot_id"].is_string() + || event["subtype"].as_str() == Some("bot_message"); + if is_bot { + match allow_bot_messages { + AllowBots::Off => { continue; } + AllowBots::Mentions | AllowBots::All => { + if !trusted_bot_ids.is_empty() { + let event_bot_id = event["bot_id"].as_str().unwrap_or(""); + let resolved = adapter.resolve_bot_user_id(event_bot_id).await; + let is_trusted = resolved.as_ref() + .is_some_and(|uid| trusted_bot_ids.contains(uid.as_str())); + if !is_trusted { + debug!(event_bot_id, resolved = ?resolved, "bot not in trusted_bot_ids, ignoring app_mention"); + continue; + } + } + } + } + } let event = event.clone(); let adapter = adapter.clone(); let bot_token = bot_token.clone(); @@ -462,7 +535,18 @@ pub async fn run_slack_adapter( let allowed_users = allowed_users.clone(); let stt_config = stt_config.clone(); let router = router.clone(); + let queue = queue.clone(); + // Queue key: thread_ts if already in a thread, otherwise ts. + // app_mention always has a channel context, so ts alone + // is unique enough (unlike message events in DMs where + // we prefix with channel_id to avoid ts collisions). + let queue_key = event["thread_ts"] + .as_str() + .or_else(|| event["ts"].as_str()) + .unwrap_or("") + .to_string(); tokio::spawn(async move { + let Some(_permit) = queue.acquire(&queue_key).await else { return }; handle_message( &event, true, @@ -597,7 +681,7 @@ pub async fn run_slack_adapter( } } - // Dispatch to handle_message + // Dispatch to handle_message (serialized per thread) let event = event.clone(); let adapter = adapter.clone(); let bot_token = bot_token.clone(); @@ -605,7 +689,19 @@ pub async fn run_slack_adapter( let allowed_users = allowed_users.clone(); let stt_config = stt_config.clone(); let router = router.clone(); + let queue = queue.clone(); + // Queue key: thread_ts if in a thread, otherwise channel:ts. + // Prefixed with channel_id for non-thread messages because + // DMs and channels can have overlapping ts values — the + // prefix ensures keys are globally unique. + let queue_key = event["thread_ts"] + .as_str() + .map(|s| s.to_string()) + .unwrap_or_else(|| { + format!("{}:{}", channel_id, event["ts"].as_str().unwrap_or("")) + }); tokio::spawn(async move { + let Some(_permit) = queue.acquire(&queue_key).await else { return }; handle_message( &event, is_dm, @@ -813,6 +909,7 @@ async fn handle_message( display_name, channel: "slack".into(), channel_id: channel_id.clone(), + thread_id: thread_ts.clone(), is_bot: is_bot_msg, };