Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/discord.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ allow_bot_messages = "mentions"
To prevent runaway bot-to-bot loops, OpenAB enforces two layers of protection:

- **Soft limit** (`max_bot_turns`, default: 20) — consecutive bot turns without human intervention. When reached, the bot sends a warning and stops responding. A human message in the thread resets the counter.
- **Hard limit** (100, not configurable) — absolute cap on total bot turns per thread. When reached, bot-to-bot conversation is permanently stopped in that thread.
- **Hard limit** (100, not configurable) — absolute cap on bot turns between human interventions. When reached, bot-to-bot conversation stops until a human replies.

```toml
[discord]
Expand Down
167 changes: 141 additions & 26 deletions src/discord.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ pub struct Handler {
pub session_ttl: std::time::Duration,
/// Configurable soft limit on bot turns per thread (reset by human message).
pub max_bot_turns: u32,
/// Per-thread counters: (soft_turns, hard_turns). Soft resets on human msg, hard never resets.
pub bot_turn_counts: tokio::sync::Mutex<HashMap<String, (u32, u32)>>,
/// Per-thread bot turn tracker. Both counters reset on human msg.
pub bot_turns: tokio::sync::Mutex<BotTurnTracker>,
}

impl Handler {
Expand Down Expand Up @@ -396,35 +396,32 @@ impl EventHandler for Handler {
// Bot turn limiting: track consecutive bot turns per thread.
// Placed after all gating so only messages that will actually be
// processed count toward the limit.
// Human message resets soft counter; hard counter never resets.
// Human message resets both soft and hard counters.
{
let thread_key = msg.channel_id.to_string();
let mut counts = self.bot_turn_counts.lock().await;
let mut tracker = self.bot_turns.lock().await;
if msg.author.bot {
let (soft, hard) = counts.entry(thread_key).or_insert((0, 0));
*soft += 1;
*hard += 1;
if *hard >= HARD_BOT_TURN_LIMIT {
tracing::warn!(channel_id = %msg.channel_id, hard = *hard, "hard bot turn limit reached");
let _ = msg.channel_id.say(
&ctx.http,
format!("🛑 Hard limit reached ({HARD_BOT_TURN_LIMIT}). Bot-to-bot conversation in this thread has been permanently stopped."),
).await;
return;
}
if *soft >= self.max_bot_turns {
tracing::info!(channel_id = %msg.channel_id, soft = *soft, max = self.max_bot_turns, "soft bot turn limit reached");
let _ = msg.channel_id.say(
&ctx.http,
format!("⚠️ Bot turn limit reached ({}/{}). A human must reply in this thread to continue bot-to-bot conversation.", *soft, self.max_bot_turns),
).await;
return;
match tracker.on_bot_message(&thread_key) {
TurnResult::HardLimit => {
tracing::warn!(channel_id = %msg.channel_id, "hard bot turn limit reached");
let _ = msg.channel_id.say(
&ctx.http,
format!("🛑 Hard limit reached ({HARD_BOT_TURN_LIMIT}). Bot-to-bot conversation in this thread has been permanently stopped."),
).await;
return;
}
TurnResult::SoftLimit(n) => {
tracing::info!(channel_id = %msg.channel_id, turns = n, max = self.max_bot_turns, "soft bot turn limit reached");
let _ = msg.channel_id.say(
&ctx.http,
format!("⚠️ Bot turn limit reached ({n}/{}). A human must reply in this thread to continue bot-to-bot conversation.", self.max_bot_turns),
).await;
return;
}
TurnResult::Ok => {}
}
} else {
// Human message: reset soft counter
if let Some((soft, _)) = counts.get_mut(&thread_key) {
*soft = 0;
}
tracker.on_human_message(&thread_key);
}
}

Expand Down Expand Up @@ -572,6 +569,46 @@ async fn get_or_create_thread(
adapter.create_thread(&parent, &trigger_ref, &thread_name).await
}

// --- Bot turn tracking ---

#[derive(Debug, PartialEq, Eq)]
pub(crate) enum TurnResult {
Ok,
SoftLimit(u32),
HardLimit,
}

pub(crate) struct BotTurnTracker {
soft_limit: u32,
counts: HashMap<String, (u32, u32)>,
}

impl BotTurnTracker {
pub fn new(soft_limit: u32) -> Self {
Self { soft_limit, counts: HashMap::new() }
}

pub fn on_bot_message(&mut self, thread_id: &str) -> TurnResult {
let (soft, hard) = self.counts.entry(thread_id.to_string()).or_insert((0, 0));
*soft += 1;
*hard += 1;
if *hard >= HARD_BOT_TURN_LIMIT {
TurnResult::HardLimit
} else if *soft >= self.soft_limit {
TurnResult::SoftLimit(*soft)
} else {
TurnResult::Ok
}
}

pub fn on_human_message(&mut self, thread_id: &str) {
if let Some((soft, hard)) = self.counts.get_mut(thread_id) {
*soft = 0;
*hard = 0;
}
}
}

static ROLE_MENTION_RE: LazyLock<regex::Regex> = LazyLock::new(|| {
regex::Regex::new(r"<@&\d+>").unwrap()
});
Expand All @@ -586,3 +623,81 @@ fn resolve_mentions(content: &str, bot_id: UserId) -> String {
let out = ROLE_MENTION_RE.replace_all(&out, "@(role)").to_string();
out.trim().to_string()
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn bot_turns_increment() {
let mut t = BotTurnTracker::new(5);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
}

#[test]
fn soft_limit_triggers() {
let mut t = BotTurnTracker::new(3);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3));
}

#[test]
fn human_resets_both_counters() {
let mut t = BotTurnTracker::new(3);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
t.on_human_message("t1");
// Both reset — can do 2 more before soft limit
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3));
}

#[test]
fn hard_limit_triggers() {
let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1);
for _ in 0..HARD_BOT_TURN_LIMIT - 1 {
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
}
assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit);
}

#[test]
fn hard_limit_resets_on_human() {
let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT + 1);
for _ in 0..HARD_BOT_TURN_LIMIT - 1 {
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
}
t.on_human_message("t1");
// Hard counter reset — can go again
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
}

#[test]
fn hard_before_soft_when_equal() {
let mut t = BotTurnTracker::new(HARD_BOT_TURN_LIMIT);
for _ in 0..HARD_BOT_TURN_LIMIT - 1 {
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
}
// soft == hard == HARD_BOT_TURN_LIMIT → hard wins
assert_eq!(t.on_bot_message("t1"), TurnResult::HardLimit);
}

#[test]
fn threads_are_independent() {
let mut t = BotTurnTracker::new(3);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::Ok);
assert_eq!(t.on_bot_message("t1"), TurnResult::SoftLimit(3));
// t2 is unaffected
assert_eq!(t.on_bot_message("t2"), TurnResult::Ok);
}

#[test]
fn human_on_unknown_thread_is_noop() {
let mut t = BotTurnTracker::new(5);
t.on_human_message("unknown"); // should not panic
}
}
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ async fn main() -> anyhow::Result<()> {
multibot_threads: tokio::sync::Mutex::new(std::collections::HashMap::new()),
session_ttl: std::time::Duration::from_secs(ttl_secs),
max_bot_turns: discord_cfg.max_bot_turns,
bot_turn_counts: tokio::sync::Mutex::new(std::collections::HashMap::new()),
bot_turns: tokio::sync::Mutex::new(discord::BotTurnTracker::new(discord_cfg.max_bot_turns)),
};

let intents = GatewayIntents::GUILD_MESSAGES
Expand Down
Loading