diff --git a/Cargo.lock b/Cargo.lock index 6b016571..11a00e04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,56 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" + +[[package]] +name = "anstyle-parse" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -128,12 +178,58 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "clap" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" + [[package]] name = "color_quant" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" +[[package]] +name = "colorchoice" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" + [[package]] name = "cpufeatures" version = "0.2.17" @@ -493,15 +589,14 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.7" +version = "0.27.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +checksum = "c2b52f86d1d4bc0d6b4e6826d960b1b333217e07d36b882dca570a5e1c48895b" dependencies = [ "http", "hyper", "hyper-util", - "rustls 0.23.37", - "rustls-pki-types", + "rustls 0.23.38", "tokio", "tokio-rustls 0.26.4", "tower-service", @@ -696,6 +791,12 @@ dependencies = [ "serde", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itoa" version = "1.0.18" @@ -851,17 +952,25 @@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "openab" -version = "0.7.3" +version = "0.7.4" dependencies = [ "anyhow", "base64", + "clap", "image", "libc", "rand 0.8.5", "regex", "reqwest", + "rpassword", "serde", "serde_json", "serenity", @@ -869,6 +978,7 @@ dependencies = [ "toml", "tracing", "tracing-subscriber", + "unicode-width", "uuid", ] @@ -987,7 +1097,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.23.37", + "rustls 0.23.38", "socket2", "thiserror 2.0.18", "tokio", @@ -1004,10 +1114,10 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.2", + "rand 0.9.3", "ring", "rustc-hash", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-pki-types", "slab", "thiserror 2.0.18", @@ -1064,9 +1174,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.2" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.5", @@ -1156,6 +1266,7 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64", "bytes", + "futures-channel", "futures-core", "futures-util", "http", @@ -1170,7 +1281,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.37", + "rustls 0.23.38", "rustls-pki-types", "serde", "serde_json", @@ -1204,6 +1315,27 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rpassword" +version = "7.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66d4c8b64f049c6721ec8ccec37ddfc3d641c4a7fca57e8f2a89de509c73df39" +dependencies = [ + "libc", + "rtoolbox", + "windows-sys 0.59.0", +] + +[[package]] +name = "rtoolbox" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327b72899159dfae8060c51a1f6aebe955245bcd9cc4997eed0f623caea022e4" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "rustc-hash" version = "2.1.2" @@ -1226,9 +1358,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.37" +version = "0.23.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +checksum = "69f9466fb2c14ea04357e91413efb882e2a6d4a406e625449bc0a5d360d53a21" dependencies = [ "once_cell", "ring", @@ -1478,6 +1610,12 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "subtle" version = "2.6.1" @@ -1665,7 +1803,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" dependencies = [ - "rustls 0.23.37", + "rustls 0.23.38", "tokio", ] @@ -1910,6 +2048,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unicode-xid" version = "0.2.6" @@ -1947,6 +2091,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.23.0" @@ -2164,6 +2314,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" diff --git a/Cargo.toml b/Cargo.toml index 26abc396..33c1dac0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,10 @@ uuid = { version = "1", features = ["v4"] } regex = "1" anyhow = "1" rand = "0.8" -reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json"] } +clap = { version = "4", features = ["derive"] } +rpassword = "7" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json", "blocking"] } base64 = "0.22" image = { version = "0.25", default-features = false, features = ["jpeg", "png", "gif", "webp"] } +unicode-width = "0.2" libc = "0.2" diff --git a/src/main.rs b/src/main.rs index fd63b89a..59330ab3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,14 +4,39 @@ mod discord; mod error_display; mod format; mod reactions; +mod setup; mod stt; +use clap::Parser; use serenity::prelude::*; use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; use tracing::info; +#[derive(Parser)] +#[command(name = "openab")] +#[command(about = "Discord bot that manages ACP agent sessions", long_about = None)] +struct Cli { + #[command(subcommand)] + command: Option, +} + +#[derive(clap::Subcommand)] +enum Commands { + /// Run the bot (default) + Run { + /// Config file path (default: config.toml) + config: Option, + }, + /// Launch the interactive setup wizard + Setup { + /// Output file path for generated config (default: config.toml) + #[arg(short, long)] + output: Option, + }, +} + #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() @@ -21,90 +46,99 @@ async fn main() -> anyhow::Result<()> { ) .init(); - let config_path = std::env::args() - .nth(1) - .map(PathBuf::from) - .unwrap_or_else(|| PathBuf::from("config.toml")); - - let mut cfg = config::load_config(&config_path)?; - info!( - agent_cmd = %cfg.agent.command, - pool_max = cfg.pool.max_sessions, - channels = ?cfg.discord.allowed_channels, - users = ?cfg.discord.allowed_users, - reactions = cfg.reactions.enabled, - allow_bot_messages = ?cfg.discord.allow_bot_messages, - "config loaded" - ); - - let pool = Arc::new(acp::SessionPool::new(cfg.agent, cfg.pool.max_sessions)); - let ttl_secs = cfg.pool.session_ttl_hours * 3600; - - let allowed_channels = parse_id_set(&cfg.discord.allowed_channels, "allowed_channels")?; - let allowed_users = parse_id_set(&cfg.discord.allowed_users, "allowed_users")?; - let trusted_bot_ids = parse_id_set(&cfg.discord.trusted_bot_ids, "trusted_bot_ids")?; - info!(channels = allowed_channels.len(), users = allowed_users.len(), trusted_bots = ?trusted_bot_ids, "parsed allowlists"); - - // Resolve STT config before constructing handler (auto-detect mutates cfg.stt) - if cfg.stt.enabled { - if cfg.stt.api_key.is_empty() && cfg.stt.base_url.contains("groq.com") { - if let Ok(key) = std::env::var("GROQ_API_KEY") { - if !key.is_empty() { - info!("stt.api_key not set, using GROQ_API_KEY from environment"); - cfg.stt.api_key = key; + let cmd = Cli::parse().command.unwrap_or(Commands::Run { config: None }); + + match cmd { + Commands::Setup { output } => { + setup::run_setup(output.map(PathBuf::from))?; + return Ok(()); + } + Commands::Run { config } => { + let config_path = config + .map(PathBuf::from) + .unwrap_or_else(|| PathBuf::from("config.toml")); + + let mut cfg = config::load_config(&config_path)?; + info!( + agent_cmd = %cfg.agent.command, + pool_max = cfg.pool.max_sessions, + channels = ?cfg.discord.allowed_channels, + users = ?cfg.discord.allowed_users, + reactions = cfg.reactions.enabled, + allow_bot_messages = ?cfg.discord.allow_bot_messages, + "config loaded" + ); + + let pool = Arc::new(acp::SessionPool::new(cfg.agent, cfg.pool.max_sessions)); + let ttl_secs = cfg.pool.session_ttl_hours * 3600; + + let allowed_channels = parse_id_set(&cfg.discord.allowed_channels, "allowed_channels")?; + let allowed_users = parse_id_set(&cfg.discord.allowed_users, "allowed_users")?; + let trusted_bot_ids = parse_id_set(&cfg.discord.trusted_bot_ids, "trusted_bot_ids")?; + info!(channels = allowed_channels.len(), users = allowed_users.len(), trusted_bots = ?trusted_bot_ids, "parsed allowlists"); + + // Resolve STT config before constructing handler (auto-detect mutates cfg.stt) + if cfg.stt.enabled { + if cfg.stt.api_key.is_empty() && cfg.stt.base_url.contains("groq.com") { + if let Ok(key) = std::env::var("GROQ_API_KEY") { + if !key.is_empty() { + info!("stt.api_key not set, using GROQ_API_KEY from environment"); + cfg.stt.api_key = key; + } + } + } + if cfg.stt.api_key.is_empty() { + anyhow::bail!("stt.enabled = true but no API key found — set stt.api_key in config or export GROQ_API_KEY"); } + info!(model = %cfg.stt.model, base_url = %cfg.stt.base_url, "STT enabled"); } - } - if cfg.stt.api_key.is_empty() { - anyhow::bail!("stt.enabled = true but no API key found — set stt.api_key in config or export GROQ_API_KEY"); - } - info!(model = %cfg.stt.model, base_url = %cfg.stt.base_url, "STT enabled"); - } - let handler = discord::Handler { - pool: pool.clone(), - allowed_channels, - allowed_users, - reactions_config: cfg.reactions, - stt_config: cfg.stt.clone(), - allow_bot_messages: cfg.discord.allow_bot_messages, - trusted_bot_ids, - }; - - let intents = GatewayIntents::GUILD_MESSAGES - | GatewayIntents::MESSAGE_CONTENT - | GatewayIntents::GUILDS; - - let mut client = Client::builder(&cfg.discord.bot_token, intents) - .event_handler(handler) - .await?; - - // Spawn cleanup task - let cleanup_pool = pool.clone(); - let cleanup_handle = tokio::spawn(async move { - loop { - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - cleanup_pool.cleanup_idle(ttl_secs).await; + let handler = discord::Handler { + pool: pool.clone(), + allowed_channels, + allowed_users, + reactions_config: cfg.reactions, + stt_config: cfg.stt.clone(), + allow_bot_messages: cfg.discord.allow_bot_messages, + trusted_bot_ids, + }; + + let intents = GatewayIntents::GUILD_MESSAGES + | GatewayIntents::MESSAGE_CONTENT + | GatewayIntents::GUILDS; + + let mut client = Client::builder(&cfg.discord.bot_token, intents) + .event_handler(handler) + .await?; + + // Spawn cleanup task + let cleanup_pool = pool.clone(); + let cleanup_handle = tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + cleanup_pool.cleanup_idle(ttl_secs).await; + } + }); + + // Run bot until SIGINT/SIGTERM + let shard_manager = client.shard_manager.clone(); + let shutdown_pool = pool.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + info!("shutdown signal received"); + shard_manager.shutdown_all().await; + }); + + info!("starting discord bot"); + client.start().await?; + + // Cleanup + cleanup_handle.abort(); + shutdown_pool.shutdown().await; + info!("openab shut down"); + Ok(()) } - }); - - // Run bot until SIGINT/SIGTERM - let shard_manager = client.shard_manager.clone(); - let shutdown_pool = pool.clone(); - tokio::spawn(async move { - tokio::signal::ctrl_c().await.ok(); - info!("shutdown signal received"); - shard_manager.shutdown_all().await; - }); - - info!("starting discord bot"); - client.start().await?; - - // Cleanup - cleanup_handle.abort(); - shutdown_pool.shutdown().await; - info!("openab shut down"); - Ok(()) + } } fn parse_id_set(raw: &[String], label: &str) -> anyhow::Result> { diff --git a/src/setup/config.rs b/src/setup/config.rs new file mode 100644 index 00000000..21d65e7e --- /dev/null +++ b/src/setup/config.rs @@ -0,0 +1,167 @@ +//! Config generation and TOML serialization for the setup wizard. + +/// Mask bot token in config output for preview +pub fn mask_bot_token(config: &str) -> String { + config + .lines() + .map(|line| { + if line.trim_start().starts_with("bot_token") { + "bot_token = \"***\"".to_string() + } else { + line.to_string() + } + }) + .collect::>() + .join("\n") +} + +#[derive(serde::Serialize)] +pub(crate) struct ConfigToml { + discord: DiscordConfigToml, + agent: AgentConfigToml, + pool: PoolConfigToml, + reactions: ReactionsConfigToml, +} + +#[derive(serde::Serialize)] +struct DiscordConfigToml { + bot_token: String, + allowed_channels: Vec, +} + +#[derive(serde::Serialize)] +struct AgentConfigToml { + command: String, + args: Vec, + working_dir: String, +} + +#[derive(serde::Serialize)] +struct PoolConfigToml { + max_sessions: usize, + session_ttl_hours: u64, +} + +#[derive(serde::Serialize)] +struct ReactionsConfigToml { + enabled: bool, + remove_after_reply: bool, + emojis: EmojisToml, + timing: TimingToml, +} + +#[derive(serde::Serialize)] +struct EmojisToml { + queued: String, + thinking: String, + tool: String, + coding: String, + web: String, + done: String, + error: String, +} + +#[derive(serde::Serialize)] +struct TimingToml { + debounce_ms: u64, + stall_soft_ms: u64, + stall_hard_ms: u64, + done_hold_ms: u64, + error_hold_ms: u64, +} + +pub fn generate_config( + bot_token: &str, + agent_command: &str, + channel_ids: Vec, + working_dir: &str, + max_sessions: usize, + session_ttl_hours: u64, +) -> String { + let config = ConfigToml { + discord: DiscordConfigToml { + bot_token: bot_token.to_string(), + allowed_channels: channel_ids, + }, + agent: { + let (command, args): (&str, Vec) = match agent_command { + "kiro" => ( + "kiro-cli", + vec!["acp".into(), "--trust-all-tools".into()], + ), + "claude" => ("claude-agent-acp", vec![]), + "codex" => ("codex-acp", vec![]), + "gemini" => ("gemini", vec!["--acp".into()]), + other => (other, vec![]), + }; + AgentConfigToml { + command: command.to_string(), + args, + working_dir: working_dir.to_string(), + } + }, + pool: PoolConfigToml { + max_sessions, + session_ttl_hours, + }, + reactions: ReactionsConfigToml { + enabled: true, + remove_after_reply: false, + emojis: EmojisToml { + queued: "👀".into(), + thinking: "🤔".into(), + tool: "🔥".into(), + coding: "👨💻".into(), + web: "⚡".into(), + done: "🆗".into(), + error: "😱".into(), + }, + timing: TimingToml { + debounce_ms: 700, + stall_soft_ms: 10_000, + stall_hard_ms: 30_000, + done_hold_ms: 1_500, + error_hold_ms: 2_500, + }, + }, + }; + toml::to_string_pretty(&config).expect("TOML serialization failed") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_config_contains_sections() { + let config = generate_config( + "my_token", + "claude", + vec!["123".to_string()], + "/home/agent", + 10, + 24, + ); + assert!(config.contains("[discord]")); + assert!(config.contains("[agent]")); + assert!(config.contains("[pool]")); + assert!(config.contains("[reactions]")); + assert!(config.contains("[reactions.emojis]")); + assert!(config.contains("[reactions.timing]")); + } + + #[test] + fn test_generate_config_kiro_working_dir() { + let config = generate_config( + "tok", + "kiro", + vec!["ch".to_string()], + "/home/agent", + 10, + 24, + ); + assert!(config.contains(r#"working_dir = "/home/agent""#)); + assert!(config.contains("acp")); + assert!(config.contains("--trust-all-tools")); + } +} diff --git a/src/setup/mod.rs b/src/setup/mod.rs new file mode 100644 index 00000000..96034f0a --- /dev/null +++ b/src/setup/mod.rs @@ -0,0 +1,12 @@ +//! OpenAB interactive setup wizard. +//! +//! Modules: +//! - `validate` — input validation (bot token, channel ID, agent command) +//! - `config` — TOML config generation and serialization +//! - `wizard` — interactive TUI, Discord API client, and wizard entry point + +mod config; +mod validate; +mod wizard; + +pub use wizard::run_setup; diff --git a/src/setup/validate.rs b/src/setup/validate.rs new file mode 100644 index 00000000..247b1b9a --- /dev/null +++ b/src/setup/validate.rs @@ -0,0 +1,73 @@ +//! Input validation functions for the setup wizard. + +/// Validate bot token format using allowlist (a-zA-Z0-9-./_) +pub fn validate_bot_token(token: &str) -> anyhow::Result<()> { + if token.is_empty() { + anyhow::bail!("Token cannot be empty"); + } + if !token + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '/' || c == '*' || c == '=') + { + anyhow::bail!( + "Token must only contain ASCII letters, numbers, dash, period, underscore, slash, or equals" + ); + } + Ok(()) +} + +/// Validate agent command +#[cfg(test)] +pub fn validate_agent_command(cmd: &str) -> anyhow::Result<()> { + let valid = ["kiro", "claude", "codex", "gemini"]; + if !valid.contains(&cmd) { + anyhow::bail!("Agent must be one of: {}", valid.join(", ")); + } + Ok(()) +} + +/// Validate channel ID is numeric +pub fn validate_channel_id(id: &str) -> anyhow::Result<()> { + if id.is_empty() { + anyhow::bail!("Channel ID cannot be empty"); + } + if !id.chars().all(|c| c.is_ascii_digit()) { + anyhow::bail!("Channel ID must be numeric only"); + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_bot_token_ok() { + assert!(validate_bot_token("simple_token").is_ok()); + assert!(validate_bot_token("token.with-dashes_123").is_ok()); + assert!(validate_bot_token("***/efgh").is_ok()); + } + + #[test] + fn test_validate_bot_token_reject_invalid() { + assert!(validate_bot_token("").is_err()); + assert!(validate_bot_token("token\nnewline").is_err()); + assert!(validate_bot_token("token\ttab").is_err()); + assert!(validate_bot_token("token with space").is_err()); + } + + #[test] + fn test_validate_agent_command() { + for agent in &["kiro", "claude", "codex", "gemini"] { + assert!(validate_agent_command(agent).is_ok()); + } + assert!(validate_agent_command("invalid").is_err()); + } + + #[test] + fn test_validate_channel_id() { + assert!(validate_channel_id("1492329565824094370").is_ok()); + assert!(validate_channel_id("").is_err()); + assert!(validate_channel_id("abc123").is_err()); + } +} diff --git a/src/setup/wizard.rs b/src/setup/wizard.rs new file mode 100644 index 00000000..8a346400 --- /dev/null +++ b/src/setup/wizard.rs @@ -0,0 +1,676 @@ +//! Interactive setup wizard TUI and Discord API client. + +use std::io::{self, IsTerminal, Write}; +use std::path::{Path, PathBuf}; + +use crate::setup::config::{generate_config, mask_bot_token}; +use crate::setup::validate::{validate_bot_token, validate_channel_id}; + +// --------------------------------------------------------------------------- +// Color codes (ANSI) +// --------------------------------------------------------------------------- + +const C: Colors = Colors { + reset: "\x1b[0m", + bold: "\x1b[1m", + cyan: "\x1b[36m", + green: "\x1b[32m", + red: "\x1b[31m", + yellow: "\x1b[33m", + magenta: "\x1b[35m", +}; + +struct Colors { + reset: &'static str, + bold: &'static str, + cyan: &'static str, + green: &'static str, + red: &'static str, + yellow: &'static str, + magenta: &'static str, +} + +const BORDER: char = '═'; + +macro_rules! cprintln { + ($color:expr, $fmt:expr) => {{ + println!("{}{}{}", $color, $fmt, C.reset); + }}; + ($color:expr, $fmt:expr, $($arg:tt)*) => {{ + println!("{}{}{}", $color, format!($fmt, $($arg)*), C.reset); + }}; +} + +// --------------------------------------------------------------------------- +// Input helpers +// --------------------------------------------------------------------------- + +fn is_interactive() -> bool { + std::io::stdin().is_terminal() && std::io::stdout().is_terminal() +} + +fn prompt(prompt_text: &str) -> String { + print!("{}{}: {}", C.yellow, prompt_text, C.reset); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + input.trim().to_string() +} + +fn prompt_default(prompt_text: &str, default: &str) -> String { + print!("{}{} [{}]: {}", C.yellow, prompt_text, default, C.reset); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim(); + if input.is_empty() { + default.to_string() + } else { + input.to_string() + } +} + +fn prompt_password(prompt_text: &str) -> String { + print!("{}{}: ", C.yellow, prompt_text); + io::stdout().flush().ok(); + rpassword::read_password().unwrap_or_default() +} + +fn prompt_yes_no(prompt_text: &str, default: bool) -> bool { + let default_str = if default { "Y/n" } else { "y/N" }; + loop { + print!("{}{} [{}]: ", C.yellow, prompt_text, default_str,); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim().to_lowercase(); + if input.is_empty() { + return default; + } + match input.as_str() { + "y" | "yes" => return true, + "n" | "no" => return false, + _ => cprintln!(C.red, "Please enter 'y' or 'n'"), + } + } +} + +fn prompt_choice(prompt_text: &str, choices: &[&str]) -> usize { + println!(); + cprintln!(C.cyan, "{}", prompt_text); + for (i, choice) in choices.iter().enumerate() { + println!(" {}. {}", i + 1, choice); + } + print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); + io::stdout().flush().ok(); + loop { + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + match input.trim().parse::() { + Ok(n) if n >= 1 && n <= choices.len() => return n - 1, + _ => { + print!("{}Select [1-{}]: {}", C.yellow, choices.len(), C.reset); + io::stdout().flush().ok(); + } + } + } +} + +fn prompt_checklist(prompt_text: &str, items: &[&str]) -> Vec { + println!(); + cprintln!(C.cyan, "{}", prompt_text); + for (i, item) in items.iter().enumerate() { + println!(" [{}] {}", i + 1, item); + } + println!(); + print!( + "{}Enter numbers separated by commas (e.g. 1,3,5) or press Enter for all: {}", + C.yellow, C.reset + ); + io::stdout().flush().ok(); + let mut input = String::new(); + io::stdin().read_line(&mut input).ok(); + let input = input.trim(); + if input.is_empty() { + return (0..items.len()).collect(); + } + input + .split(',') + .filter_map(|s| s.trim().parse::().ok()) + .filter(|n| *n >= 1 && *n <= items.len()) + .map(|n| n - 1) + .collect() +} + +// --------------------------------------------------------------------------- +// Box drawing helpers +// --------------------------------------------------------------------------- + +fn print_box(lines: &[&str]) { + let width = lines + .iter() + .map(|l| unicode_width::UnicodeWidthStr::width(&**l)) + .max() + .unwrap_or(60); + let width = width.clamp(60, 76); + println!(); + cprintln!(C.cyan, "{}", "╔".to_string() + &BORDER.to_string().repeat(width + 2) + "╗"); + for line in lines { + let padded = format!(" {: Self { + Self { + token: token.to_string(), + http: reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(10)) + .build() + .expect("static HTTP client must build"), + } + } + + /// Verify token by fetching bot info + fn verify_token(&self) -> anyhow::Result<(String, String)> { + let resp = self + .http + .get("https://discord.com/api/v10/users/@me") + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Token verification failed: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct MeResponse { + id: String, + username: String, + } + let me: MeResponse = resp.json()?; + Ok((me.id, me.username)) + } + + /// Fetch guilds the bot is in + fn fetch_guilds(&self) -> anyhow::Result> { + let resp = self + .http + .get("https://discord.com/api/v10/users/@me/guilds") + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Failed to fetch guilds: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct Guild { + id: String, + name: String, + } + let guilds: Vec = resp.json()?; + Ok(guilds.into_iter().map(|g| (g.id, g.name)).collect()) + } + + /// Fetch channels in a guild + fn fetch_channels(&self, guild_id: &str) -> anyhow::Result> { + let url = format!("https://discord.com/api/v10/guilds/{}/channels", guild_id); + let resp = self + .http + .get(&url) + .header("Authorization", format!("Bot {}", self.token)) + .header("User-Agent", "OpenAB setup wizard") + .send()?; + if !resp.status().is_success() { + anyhow::bail!("Failed to fetch channels: HTTP {}", resp.status()); + } + #[derive(serde::Deserialize)] + struct Channel { + id: String, + #[serde(rename = "type")] + kind: u8, + name: String, + } + let channels: Vec = resp.json()?; + // type 0 = text channel + Ok(channels + .into_iter() + .filter(|c| c.kind == 0) + .map(|c| (c.id, c.name, guild_id.to_string())) + .collect()) + } +} + +// --------------------------------------------------------------------------- +// Section 1: Discord Bot Setup Guide +// --------------------------------------------------------------------------- + +fn section_discord_guide() { + print_box(&[ + "Discord Bot Setup Guide", + "", + "1. Go to: https://discord.com/developers/applications", + "2. Click 'New Application' -> name it (e.g. OpenAB)", + "3. Bot -> Reset Token -> COPY the token", + "", + "4. Enable Privileged Gateway Intents:", + " - Message Content Intent", + " - Guild Members Intent", + "", + "5. OAuth2 -> URL Generator:", + " - SCOPES: bot", + " - BOT PERMISSIONS:", + " Send Messages | Embed Links | Attach Files", + " Read Message History | Add Reactions", + " Use Slash Commands", + "", + "6. Visit the generated URL -> add bot to your server", + ]); +} + +// --------------------------------------------------------------------------- +// Section 2: Channel Selection +// --------------------------------------------------------------------------- + +fn section_channels(client: &DiscordClient) -> anyhow::Result> { + println!(); + cprintln!(C.bold, "--- Step 2: Allowed Channels ---"); + println!(); + + print!(" Fetching servers... "); + io::stdout().flush().ok(); + let guilds = client.fetch_guilds()?; + cprintln!(C.green, "OK Found {} server(s)", guilds.len()); + println!(); + + if guilds.is_empty() { + cprintln!( + C.yellow, + " No servers found. Enter channel IDs manually." + ); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id)?; + } + return Ok(ids); + } + + let guild_names: Vec<&str> = guilds.iter().map(|(_, n)| n.as_str()).collect(); + let guild_idx = prompt_choice(" Select server:", &guild_names); + let (guild_id, guild_name) = &guilds[guild_idx]; + + print!(" Fetching channels in '{}'... ", guild_name); + io::stdout().flush().ok(); + let channels = client.fetch_channels(guild_id)?; + cprintln!(C.green, "OK Found {} channel(s)", channels.len()); + println!(); + + if channels.is_empty() { + cprintln!( + C.yellow, + " No text channels found. Enter channel IDs manually." + ); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id)?; + } + return Ok(ids); + } + + let channel_names: Vec = channels + .iter() + .map(|(_, n, _)| format!("#{}", n)) + .collect(); + let channel_names_refs: Vec<&str> = channel_names + .iter() + .map(|s| s.as_str()) + .collect(); + + let selected = + prompt_checklist(" Select channels (by number):", &channel_names_refs); + let selected_ids: Vec = selected + .iter() + .map(|&i| channels[i].0.clone()) + .collect(); + + println!(); + cprintln!(C.green, " Selected {} channel(s)", selected_ids.len()); + for id in &selected_ids { + if let Some((_, name, _)) = channels.iter().find(|(cid, _, _)| cid == id) { + println!(" * #{}", name); + } else { + println!(" * {}", id); + } + } + println!(); + + Ok(selected_ids) +} + +// --------------------------------------------------------------------------- +// Section 3: Agent Configuration +// --------------------------------------------------------------------------- + +fn section_agent() -> (String, String, bool) { + println!(); + cprintln!(C.bold, "--- Step 3: Agent Configuration ---"); + println!(); + + print_box(&[ + "Agent Installation Guide", + "", + "claude: npm install -g @anthropic-ai/claude-code", + "kiro: npm install -g @koryhutchison/kiro-cli", + "codex: npm install -g openai-codex (requires OpenAI API key)", + "gemini: npm install -g @google/gemini-cli", + "", + "Make sure the agent is in your PATH before continuing.", + ]); + println!(); + + let choices = ["claude", "kiro", "codex", "gemini"]; + let idx = prompt_choice(" Select agent:", &choices); + let agent = choices[idx]; + + let deploy_choices = ["Local (current directory)", "Docker / k8s"]; + let deploy_idx = prompt_choice(" Deployment target:", &deploy_choices); + let is_local = deploy_idx == 0; + let default_dir = match (is_local, agent) { + (true, _) => ".", + (false, "kiro") => "/home/agent", + (false, _) => "/home/node", + }; + + let working_dir = prompt_default(" Working directory", default_dir); + + cprintln!( + C.green, + " Agent: {} | Working dir: {}", + agent, + working_dir + ); + println!(); + + (agent.to_string(), working_dir, is_local) +} + +// --------------------------------------------------------------------------- +// Section 4: Pool Settings +// --------------------------------------------------------------------------- + +fn section_pool() -> (usize, u64) { + println!(); + cprintln!(C.bold, "--- Step 4: Session Pool ---"); + println!(); + + let max_sessions: usize = prompt_default(" Max sessions", "10") + .parse() + .unwrap_or(10); + let ttl_hours: u64 = prompt_default(" Session TTL (hours)", "24") + .parse() + .unwrap_or(24); + + cprintln!( + C.green, + " Max sessions: {} | TTL: {}h", + max_sessions, + ttl_hours + ); + println!(); + + (max_sessions, ttl_hours) +} + +// --------------------------------------------------------------------------- +// Preview & Save +// --------------------------------------------------------------------------- + +fn section_preview_and_save(config_content: &str, output_path: &PathBuf) -> anyhow::Result<()> { + println!(); + cprintln!(C.bold, "--- Preview ---"); + println!(); + println!("{}", mask_bot_token(config_content)); + println!(); + + if output_path.exists() + && !prompt_yes_no(" File exists. Overwrite?", false) + { + println!(" Saving cancelled."); + return Ok(()); + } + + std::fs::write(output_path, config_content)?; + cprintln!(C.green, "OK config.toml saved to {}", output_path.display()); + println!(); + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Non-interactive guidance +// --------------------------------------------------------------------------- + +fn print_noninteractive_guide() { + print_box(&[ + "Non-Interactive Mode", + "", + "The interactive wizard requires a terminal.", + "Create config.toml manually, then run:", + "", + " openab run config.toml", + "", + "Config format reference:", + " [discord]", + " bot_token = \"YOUR_BOT_TOKEN\"", + " allowed_channels = [\"CHANNEL_ID\"]", + "", + " [agent]", + " command = \"kiro-cli\"", + " args = [\"acp\", \"--trust-all-tools\"]", + " working_dir = \"/home/agent\"", + "", + " [pool]", + " max_sessions = 10", + " session_ttl_hours = 24", + "", + " [reactions]", + " enabled = true", + " remove_after_reply = false", + " ...", + ]); +} + +// --------------------------------------------------------------------------- +// Next steps printer +// --------------------------------------------------------------------------- + +fn print_next_steps(agent: &str, output_path: &Path, is_local: bool) { + println!(); + cprintln!(C.bold, "--- Next Steps ---"); + println!(); + + if is_local { + match agent { + "kiro" => { + cprintln!(C.cyan, " 1. Install kiro-cli (see https://kiro.dev for installer)"); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" kiro-cli login --use-device-flow"); + } + "claude" => { + cprintln!(C.cyan, " 1. Install Claude Code + ACP adapter:"); + println!(" npm install -g @anthropic-ai/claude-code @agentclientprotocol/claude-agent-acp"); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" claude setup-token"); + } + "codex" => { + cprintln!(C.cyan, " 1. Install Codex CLI + ACP adapter:"); + println!(" npm install -g @openai/codex @zed-industries/codex-acp"); + cprintln!(C.cyan, " 2. Authenticate:"); + println!(" codex login --device-auth"); + } + "gemini" => { + cprintln!(C.cyan, " 1. Install Gemini CLI:"); + println!(" npm install -g @google/gemini-cli"); + cprintln!(C.cyan, " 2. Authenticate via Google OAuth, or set GEMINI_API_KEY in config.toml"); + } + _ => {} + } + + println!(); + cprintln!(C.green, " 3. Run the bot:"); + println!(" cargo run -- run {}", output_path.display()); + } else { + cprintln!( + C.cyan, + " Docker image already bundles the agent CLI and ACP adapter." + ); + println!(); + cprintln!(C.cyan, " 1. Deploy with Helm (or your preferred method):"); + println!(" helm install openab openab/openab \\"); + println!(" --set agents.{}.discord.botToken=\"$BOT_TOKEN\"", agent); + println!(); + cprintln!(C.cyan, " 2. Authenticate inside the pod (first time only):"); + match agent { + "kiro" => println!( + " kubectl exec -it deployment/openab-kiro -- kiro-cli login --use-device-flow" + ), + "claude" => println!( + " kubectl exec -it deployment/openab-claude -- claude setup-token" + ), + "codex" => println!( + " kubectl exec -it deployment/openab-codex -- codex login --device-auth" + ), + "gemini" => println!( + " Set GEMINI_API_KEY via secret, or exec into the pod for OAuth" + ), + _ => {} + } + println!(); + cprintln!(C.green, " See README for full Helm options."); + } + println!(); +} + +// --------------------------------------------------------------------------- +// Main wizard entry point +// --------------------------------------------------------------------------- + +pub fn run_setup(output_path: Option) -> anyhow::Result<()> { + if !is_interactive() { + print_noninteractive_guide(); + return Ok(()); + } + + println!(); + cprintln!( + C.magenta, + "============================================================" + ); + cprintln!( + C.magenta, + " OpenAB Interactive Setup Wizard " + ); + cprintln!( + C.magenta, + "============================================================" + ); + + // Step 1: Discord Guide + Token + section_discord_guide(); + println!(); + let bot_token = prompt_password(" Bot Token (or press Enter to skip)"); + if bot_token.is_empty() { + cprintln!( + C.yellow, + " Skipped. Set bot_token manually in config.toml" + ); + println!(); + cprintln!( + C.green, + " Setup complete! Edit config.toml to add your bot token." + ); + return Ok(()); + } + validate_bot_token(&bot_token)?; + + let client = DiscordClient::new(&bot_token); + print!(" Verifying token with Discord API... "); + io::stdout().flush().ok(); + let (_bot_id, bot_username) = client.verify_token()?; + cprintln!(C.green, "OK Logged in as {}", bot_username); + + // Step 2: Channels + let channel_ids = match section_channels(&client) { + Ok(ids) if !ids.is_empty() => ids, + Ok(_) => { + cprintln!(C.yellow, " No channels selected."); + vec![] + } + Err(e) => { + cprintln!( + C.yellow, + " Channel fetch failed: {}. Enter manually.", + e + ); + let input = prompt(" Channel ID(s), comma-separated"); + let ids: Vec = input + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect(); + for id in &ids { + validate_channel_id(id).map_err(|e| anyhow::anyhow!("{}", e))?; + } + ids + } + }; + + // Step 3: Agent + let (agent, working_dir, is_local) = section_agent(); + + // Step 4: Pool + let (max_sessions, ttl_hours) = section_pool(); + + // Generate + let config_content = generate_config( + &bot_token, + &agent, + channel_ids, + &working_dir, + max_sessions, + ttl_hours, + ); + + // Output + let output_path = output_path.unwrap_or_else(|| PathBuf::from("config.toml")); + section_preview_and_save(&config_content, &output_path)?; + + print_next_steps(&agent, &output_path, is_local); + + Ok(()) +}