diff --git a/Cargo.lock b/Cargo.lock index c2a3212aa3..ea0cda152c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1442,7 +1442,7 @@ dependencies = [ "downcast-rs", "libsqlite3-sys", "r2d2", - "sqlite-wasm-rs", + "sqlite-wasm-rs 0.4.8", "time", ] @@ -1775,6 +1775,18 @@ dependencies = [ "rand 0.10.0", ] +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fastrand" version = "2.3.0" @@ -1866,6 +1878,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "forge_api" version = "0.1.0" @@ -2119,6 +2137,7 @@ dependencies = [ "console 0.16.3", "convert_case 0.11.0", "derive_setters", + "dirs", "enable-ansi-support", "fake", "forge_api", @@ -2145,6 +2164,7 @@ dependencies = [ "open", "pretty_assertions", "reedline", + "rusqlite", "rustls 0.23.37", "serde", "serde_json", @@ -2873,7 +2893,7 @@ version = "0.15.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -2881,6 +2901,9 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "foldhash 0.2.0", +] [[package]] name = "hashlink" @@ -2891,6 +2914,15 @@ dependencies = [ "hashbrown 0.15.5", ] +[[package]] +name = "hashlink" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea0b22561a9c04a7cb1a302c013e0259cd3b4bb619f145b32f72b8b4bcbed230" +dependencies = [ + "hashbrown 0.16.1", +] + [[package]] name = "heck" version = "0.5.0" @@ -5262,6 +5294,31 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "rsqlite-vfs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8a1f2315036ef6b1fbacd1972e8ee7688030b0a2121edfc2a6550febd41574d" +dependencies = [ + "hashbrown 0.16.1", + "thiserror 2.0.18", +] + +[[package]] +name = "rusqlite" +version = "0.38.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c93dd1c9683b438c392c492109cb702b8090b2bfc8fed6f6e4eb4523f17af3" +dependencies = [ + "bitflags 2.10.0", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink 0.11.0", + "libsqlite3-sys", + "smallvec", + "sqlite-wasm-rs 0.5.2", +] + [[package]] name = "rust-ini" version = "0.21.3" @@ -5922,6 +5979,18 @@ dependencies = [ "web-sys", ] +[[package]] +name = "sqlite-wasm-rs" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f4206ed3a67690b9c29b77d728f6acc3ce78f16bf846d83c94f76400320181b" +dependencies = [ + "cc", + "js-sys", + "rsqlite-vfs", + "wasm-bindgen", +] + [[package]] name = "sse-stream" version = "0.2.1" @@ -8034,7 +8103,7 @@ checksum = "2462ea039c445496d8793d052e13787f2b90e750b833afee748e601c17621ed9" dependencies = [ "arraydeque", "encoding_rs", - "hashlink", + "hashlink 0.10.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9a9f43786a..11d0475359 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,6 +119,7 @@ uuid = { version = "1.22.0", features = [ whoami = "2.1.0" fnv_rs = "0.4.3" merge = { version = "0.2", features = ["derive"] } +rusqlite = { version = "0.38", features = ["bundled"] } rmcp = { version = "0.10.0", features = [ "client", "transport-sse-client-reqwest", diff --git a/crates/forge_main/Cargo.toml b/crates/forge_main/Cargo.toml index a96637b215..5393d690ef 100644 --- a/crates/forge_main/Cargo.toml +++ b/crates/forge_main/Cargo.toml @@ -54,6 +54,8 @@ open.workspace = true humantime.workspace = true num-format.workspace = true atty = "0.2" +dirs = "6.0.0" +rusqlite.workspace = true url.workspace = true forge_embed.workspace = true include_dir.workspace = true diff --git a/crates/forge_main/src/lib.rs b/crates/forge_main/src/lib.rs index c5b342df7c..812c1fe549 100644 --- a/crates/forge_main/src/lib.rs +++ b/crates/forge_main/src/lib.rs @@ -9,6 +9,7 @@ mod input; mod model; mod porcelain; mod prompt; +pub mod rprompt_fast; mod sandbox; mod state; mod stream_renderer; @@ -17,7 +18,7 @@ mod title_display; mod tools_display; pub mod tracker; mod ui; -mod utils; +pub mod utils; mod vscode; mod zsh; diff --git a/crates/forge_main/src/main.rs b/crates/forge_main/src/main.rs index 3f680bdb4d..4cd11621ad 100644 --- a/crates/forge_main/src/main.rs +++ b/crates/forge_main/src/main.rs @@ -6,7 +6,7 @@ use anyhow::Result; use clap::Parser; use forge_api::ForgeAPI; use forge_domain::TitleFormat; -use forge_main::{Cli, Sandbox, TitleDisplayExt, UI, tracker}; +use forge_main::{Cli, Sandbox, TitleDisplayExt, UI, rprompt_fast, tracker, utils}; #[tokio::main] async fn main() -> Result<()> { @@ -35,6 +35,76 @@ async fn main() -> Result<()> { })); // Initialize and run the UI + + // Fast path: zsh rprompt without conversation ID - check BEFORE Cli::parse() + let args: Vec = std::env::args().collect(); + let has_conv = std::env::var("_FORGE_CONVERSATION_ID") + .ok() + .filter(|s| !s.trim().is_empty()) + .is_some(); + if args.len() >= 3 && args[1] == "zsh" && args[2] == "rprompt" && !has_conv { + println!(" %B%F{{240}}󱙺 FORGE%f%b"); + return Ok(()); + } + + // Fast path: zsh rprompt WITH conversation ID - direct SQLite query + if args.len() >= 3 + && args[1] == "zsh" + && args[2] == "rprompt" + && has_conv + && let Ok(conv_id) = std::env::var("_FORGE_CONVERSATION_ID") + { + let conv_id = conv_id.trim(); + if !conv_id.is_empty() { + // Try fast path - if it fails, fall through to normal path + if let Some(data) = rprompt_fast::fetch_rprompt_data(conv_id) { + let use_nerd_font = std::env::var("NERD_FONT") + .or_else(|_| std::env::var("USE_NERD_FONT")) + .map(|v| v == "1") + .unwrap_or(true); + + // Check if we have token count (active state) or just show inactive + if let Some(token_count) = data.token_count { + let icon = if use_nerd_font { "󱙺" } else { "" }; + let count_str = utils::humanize_number(token_count); + + // Active state: bright colors + print!(" %B%F{{15}}{} FORGE%f%b %B%F{{15}}{}%f%b", icon, count_str); + + if let Some(cost) = data.cost { + let currency = std::env::var("FORGE_CURRENCY_SYMBOL") + .unwrap_or_else(|_| "$".to_string()); + let ratio: f64 = std::env::var("FORGE_CURRENCY_CONVERSION_RATE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(1.0); + print!(" %B%F{{2}}{}{:.2}%f%b", currency, cost * ratio); + } + + if let Some(ref model) = data.model { + let model_icon = if use_nerd_font { "󰑙" } else { "" }; + print!(" %F{{134}}{}{}", model_icon, model); + } + + println!(); + return Ok(()); + } else { + // No token count - show inactive/dimmed state + let icon = if use_nerd_font { "󱙺" } else { "" }; + let model_str = data.model.as_deref().unwrap_or("forge"); + let model_icon = if use_nerd_font { "󰑙" } else { "" }; + + print!( + " %B%F{{240}}{} FORGE%f%b %F{{240}}{}{}%f", + icon, model_icon, model_str + ); + println!(); + return Ok(()); + } + } + } + } + let mut cli = Cli::parse(); // Check if there's piped input diff --git a/crates/forge_main/src/rprompt_fast.rs b/crates/forge_main/src/rprompt_fast.rs new file mode 100644 index 0000000000..d94ba194aa --- /dev/null +++ b/crates/forge_main/src/rprompt_fast.rs @@ -0,0 +1,185 @@ +//! Fast rprompt data fetcher using direct SQLite access. +//! +//! This module provides a lightweight way to fetch rprompt data (token count, +//! cost, model) directly from the SQLite database without loading the full +//! Forge infrastructure stack. + +use std::path::PathBuf; + +/// Data fetched from the database for rprompt display +#[derive(Debug, Default)] +pub struct RpromptData { + pub token_count: Option, + pub cost: Option, + pub model: Option, +} + +/// Fetches rprompt data from the SQLite database directly. +/// +/// This is a fast path that bypasses the full Forge infrastructure stack. +/// Returns None on any error (DB not found, locked, invalid ID, etc.). +pub fn fetch_rprompt_data(conversation_id: &str) -> Option { + let db_path = get_database_path()?; + let conn = rusqlite::Connection::open(&db_path).ok()?; + + let context: String = conn + .query_row( + "SELECT context FROM conversations WHERE conversation_id = ?1", + [conversation_id], + |row| row.get(0), + ) + .ok()?; + + // Use in-memory SQLite for JSON extraction + let mem_conn = rusqlite::Connection::open_in_memory().ok()?; + + let token_count = extract_token_count(&mem_conn, &context); + let cost = extract_cost(&mem_conn, &context); + let model = extract_model(&mem_conn, &context); + + Some(RpromptData { token_count, cost, model }) +} + +fn get_database_path() -> Option { + // Use current working directory, matching how forge resolves the DB path + // The DB is at .forge/forge.db relative to the project directory + let cwd = std::env::current_dir().ok()?; + Some(cwd.join(".forge").join("forge.db")) +} + +fn extract_token_count(conn: &rusqlite::Connection, context: &str) -> Option { + // Try top-level usage.total_tokens + let result: Option = conn + .query_row( + "SELECT json_extract(?1, '$.usage.total_tokens')", + [context], + |row| row.get(0), + ) + .ok(); + + if let Some(val) = result { + return parse_token_value(&val); + } + + // Fallback: last message's usage.total_tokens + let messages: Option = conn + .query_row("SELECT json_extract(?1, '$.messages')", [context], |row| { + row.get(0) + }) + .ok()?; + + let last_message: Option = conn + .query_row("SELECT json_extract(?1, '$[-1]')", [&messages], |row| { + row.get(0) + }) + .ok(); + + if let Some(msg) = last_message { + let result: Option = conn + .query_row( + "SELECT json_extract(?1, '$.usage.total_tokens')", + [&msg], + |row| row.get(0), + ) + .ok(); + if let Some(val) = result { + return parse_token_value(&val); + } + } + + None +} + +fn parse_token_value(val: &str) -> Option { + let val = val.trim(); + + if let Some(inner) = val + .strip_prefix("Actual(") + .or_else(|| val.strip_prefix("Approx(")) + && let Some(num) = inner.strip_suffix(')') + { + return num.parse().ok(); + } + + val.parse().ok() +} + +fn extract_cost(conn: &rusqlite::Connection, context: &str) -> Option { + let result: Option = conn + .query_row( + "SELECT json_extract(?1, '$.usage.cost')", + [context], + |row| row.get(0), + ) + .ok()?; + + result.and_then(|s| s.parse().ok()) +} + +fn extract_model(conn: &rusqlite::Connection, context: &str) -> Option { + let result: Option = conn + .query_row( + "SELECT json_extract(?1, '$.usage.model')", + [context], + |row| row.get(0), + ) + .ok()?; + + let result = result?; + if result.len() >= 2 && result.starts_with('"') && result.ends_with('"') { + Some(result[1..result.len() - 1].to_string()) + } else { + Some(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_mem_conn() -> Option { + rusqlite::Connection::open_in_memory().ok() + } + + #[test] + fn test_extract_token_count_actual() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"total_tokens": "Actual(1500)"}}"#; + assert_eq!(extract_token_count(&conn, context), Some(1500)); + } + + #[test] + fn test_extract_token_count_approx() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"total_tokens": "Approx(100)"}}"#; + assert_eq!(extract_token_count(&conn, context), Some(100)); + } + + #[test] + fn test_extract_token_count_raw() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"total_tokens": "2000"}}"#; + assert_eq!(extract_token_count(&conn, context), Some(2000)); + } + + #[test] + fn test_extract_cost() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"cost": "0.0123"}}"#; + assert_eq!(extract_cost(&conn, context), Some(0.0123)); + } + + #[test] + fn test_extract_model() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"model": "gpt-4"}}"#; + assert_eq!(extract_model(&conn, context), Some("gpt-4".to_string())); + } + + #[test] + fn test_extract_model_single_quote_edge_case() { + let conn = create_mem_conn().unwrap(); + let context = r#"{"usage": {"model": "\""}}"#; + assert_eq!(extract_model(&conn, context), Some("\"".to_string())); + } +}