diff --git a/docs/discord.md b/docs/discord.md index b35ed6f..7b7896f 100644 --- a/docs/discord.md +++ b/docs/discord.md @@ -217,7 +217,7 @@ allow_bot_messages = "mentions" To prevent runaway bot-to-bot loops, OpenAB enforces two layers of protection: - **Soft limit** (`max_bot_turns`, default: 20) — consecutive bot turns without human intervention. When reached, the bot sends a warning and stops responding. A human message in the thread resets the counter. -- **Hard limit** (100, not configurable) — absolute cap on total bot turns per thread. When reached, bot-to-bot conversation is permanently stopped in that thread. +- **Hard limit** (100, not configurable) — absolute cap on bot turns between human interventions. When reached, bot-to-bot conversation stops until a human replies. ```toml [discord] diff --git a/src/discord.rs b/src/discord.rs index 65002dd..35508c8 100644 --- a/src/discord.rs +++ b/src/discord.rs @@ -126,8 +126,8 @@ pub struct Handler { pub session_ttl: std::time::Duration, /// Configurable soft limit on bot turns per thread (reset by human message). pub max_bot_turns: u32, - /// Per-thread counters: (soft_turns, hard_turns). Soft resets on human msg, hard never resets. - pub bot_turn_counts: tokio::sync::Mutex>, + /// Per-thread bot turn tracker. Both counters reset on human msg. + pub bot_turns: tokio::sync::Mutex, } impl Handler { @@ -396,35 +396,32 @@ impl EventHandler for Handler { // Bot turn limiting: track consecutive bot turns per thread. // Placed after all gating so only messages that will actually be // processed count toward the limit. - // Human message resets soft counter; hard counter never resets. + // Human message resets both soft and hard counters. { let thread_key = msg.channel_id.to_string(); - let mut counts = self.bot_turn_counts.lock().await; + let mut tracker = self.bot_turns.lock().await; if msg.author.bot { - let (soft, hard) = counts.entry(thread_key).or_insert((0, 0)); - *soft += 1; - *hard += 1; - if *hard >= HARD_BOT_TURN_LIMIT { - tracing::warn!(channel_id = %msg.channel_id, hard = *hard, "hard bot turn limit reached"); - let _ = msg.channel_id.say( - &ctx.http, - format!("🛑 Hard limit reached ({HARD_BOT_TURN_LIMIT}). Bot-to-bot conversation in this thread has been permanently stopped."), - ).await; - return; - } - if *soft >= self.max_bot_turns { - tracing::info!(channel_id = %msg.channel_id, soft = *soft, max = self.max_bot_turns, "soft bot turn limit reached"); - let _ = msg.channel_id.say( - &ctx.http, - format!("⚠️ Bot turn limit reached ({}/{}). A human must reply in this thread to continue bot-to-bot conversation.", *soft, self.max_bot_turns), - ).await; - return; + match tracker.on_bot_message(&thread_key) { + TurnResult::HardLimit => { + tracing::warn!(channel_id = %msg.channel_id, "hard bot turn limit reached"); + let _ = msg.channel_id.say( + &ctx.http, + format!("🛑 Hard limit reached ({HARD_BOT_TURN_LIMIT}). Bot-to-bot conversation in this thread has been permanently stopped."), + ).await; + return; + } + TurnResult::SoftLimit(n) => { + tracing::info!(channel_id = %msg.channel_id, turns = n, max = self.max_bot_turns, "soft bot turn limit reached"); + let _ = msg.channel_id.say( + &ctx.http, + format!("⚠️ Bot turn limit reached ({n}/{}). A human must reply in this thread to continue bot-to-bot conversation.", self.max_bot_turns), + ).await; + return; + } + TurnResult::Ok => {} } } else { - // Human message: reset soft counter - if let Some((soft, _)) = counts.get_mut(&thread_key) { - *soft = 0; - } + tracker.on_human_message(&thread_key); } } @@ -572,6 +569,46 @@ async fn get_or_create_thread( adapter.create_thread(&parent, &trigger_ref, &thread_name).await } +// --- Bot turn tracking --- + +#[derive(Debug, PartialEq, Eq)] +pub(crate) enum TurnResult { + Ok, + SoftLimit(u32), + HardLimit, +} + +pub(crate) struct BotTurnTracker { + soft_limit: u32, + counts: HashMap, +} + +impl BotTurnTracker { + pub fn new(soft_limit: u32) -> Self { + Self { soft_limit, counts: HashMap::new() } + } + + pub fn on_bot_message(&mut self, thread_id: &str) -> TurnResult { + let (soft, hard) = self.counts.entry(thread_id.to_string()).or_insert((0, 0)); + *soft += 1; + *hard += 1; + if *hard >= HARD_BOT_TURN_LIMIT { + TurnResult::HardLimit + } else if *soft >= self.soft_limit { + TurnResult::SoftLimit(*soft) + } else { + TurnResult::Ok + } + } + + pub fn on_human_message(&mut self, thread_id: &str) { + if let Some((soft, hard)) = self.counts.get_mut(thread_id) { + *soft = 0; + *hard = 0; + } + } +} + static ROLE_MENTION_RE: LazyLock = LazyLock::new(|| { regex::Regex::new(r"<@&\d+>").unwrap() }); @@ -586,3 +623,81 @@ fn resolve_mentions(content: &str, bot_id: UserId) -> String { let out = ROLE_MENTION_RE.replace_all(&out, "@(role)").to_string(); out.trim().to_string() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bot_turns_increment() { + let mut t = BotTurnTracker::new(5); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + + #[test] + fn soft_limit_triggers() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + #[test] + fn human_resets_both_counters() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + t.on_human_message("t1"); + // Both reset — can do 2 more before soft limit + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + } + + #[test] + fn hard_limit_triggers() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + } + + #[test] + fn hard_limit_resets_on_human() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + t.on_human_message("t1"); + // Hard counter reset — can go again + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + + #[test] + fn hard_before_soft_when_equal() { + let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT); + for _ in 0..HARD_BOT_TURN_LIMIT - 1 { + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + } + // soft == hard == HARD_BOT_TURN_LIMIT → hard wins + assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit); + } + + #[test] + fn threads_are_independent() { + let mut t = BotTurnTracker::new(3); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::Ok); + assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3)); + // t2 is unaffected + assert_eq!(t.on_bot_message("t2"), TurnResult::Ok); + } + + #[test] + fn human_on_unknown_thread_is_noop() { + let mut t = BotTurnTracker::new(5); + t.on_human_message("unknown"); // should not panic + } +} diff --git a/src/main.rs b/src/main.rs index bf4d1b2..28d38d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -172,7 +172,7 @@ async fn main() -> anyhow::Result<()> { multibot_threads: tokio::sync::Mutex::new(std::collections::HashMap::new()), session_ttl: std::time::Duration::from_secs(ttl_secs), max_bot_turns: discord_cfg.max_bot_turns, - bot_turn_counts: tokio::sync::Mutex::new(std::collections::HashMap::new()), + bot_turns: tokio::sync::Mutex::new(discord::BotTurnTracker::new(discord_cfg.max_bot_turns)), }; let intents = GatewayIntents::GUILD_MESSAGES