From a4a76dc7025501024161badca5b06c1d2b145c65 Mon Sep 17 00:00:00 2001 From: Patrick szymkowiak Date: Thu, 19 Mar 2026 08:11:07 +0100 Subject: [PATCH] test: add 47 comprehensive tests across all modules - SQLite lock store (9): lock/release, blocking, relock, release_all, all_locks, locks_for_agent, gc_expired, refresh_ttl, concurrent access - Database (10): schema init, upsert/count symbols, list/search/filter, available symbols, session lifecycle, integrity check - Config (5): default config, save/load roundtrip, missing file fallback, malformed JSON resilience, S3 config roundtrip - Parser AST (13): Rust/TypeScript/Python/JavaScript parsing, struct+impl, enum+trait, interfaces, symbol ID format, hash determinism, node_modules skip, kind normalization - CLI validation (10): valid/invalid identifiers, path traversal, argument injection, lock expiry computation --- Cargo.lock | 33 ++++ Cargo.toml | 3 + src/cli/mod.rs | 80 +++++++++ src/config.rs | 68 ++++++++ src/db/mod.rs | 170 +++++++++++++++++++ src/db/sqlite_store.rs | 212 ++++++++++++++++++++++++ src/parser/mod.rs | 360 +++++++++++++++++++++++++++++++++++++++++ 7 files changed, 926 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 610589b..58829df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1061,6 +1061,7 @@ dependencies = [ "rusqlite", "serde", "serde_json", + "tempfile", "tokio", "tree-sitter", "tree-sitter-javascript", @@ -1534,6 +1535,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "litemap" version = "0.8.1" @@ -1806,6 +1813,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", 
+ "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.21.12" @@ -2134,6 +2154,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "time" version = "0.3.47" diff --git a/Cargo.toml b/Cargo.toml index 5ebd02c..e9e69eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,3 +43,6 @@ colored = "2" aws-config = { version = "1", features = ["behavior-version-latest"] } aws-sdk-s3 = "1" urlencoding = "2" + +[dev-dependencies] +tempfile = "3" diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 42b087e..cfd455f 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -933,3 +933,83 @@ fn cmd_config_show(repo: &str) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::lock_store::LockEntry; + + // ── validate_identifier tests ── + + #[test] + fn test_validate_identifier_valid() { + assert!(validate_identifier("agent-1", "id").is_ok()); + assert!(validate_identifier("my_agent", "id").is_ok()); + assert!(validate_identifier("agent.v2", "id").is_ok()); + assert!(validate_identifier("abc123", "id").is_ok()); + } + + #[test] + fn test_validate_identifier_empty() { + assert!(validate_identifier("", "id").is_err()); + } + + #[test] + fn test_validate_identifier_path_traversal() { + assert!(validate_identifier("..", "id").is_err()); + } + + #[test] + fn test_validate_identifier_slash() { + assert!(validate_identifier("foo/bar", "id").is_err()); + } + + #[test] + fn test_validate_identifier_backslash() { + assert!(validate_identifier("foo\\bar", "id").is_err()); + } + + #[test] + fn test_validate_identifier_starts_with_dash() { + assert!(validate_identifier("-agent", "id").is_err()); + } + + #[test] + fn 
test_validate_identifier_special_chars() { + assert!(validate_identifier("foo@bar", "id").is_err()); + assert!(validate_identifier("foo bar", "id").is_err()); + assert!(validate_identifier("foo;rm", "id").is_err()); + } + + // ── is_entry_expired_local tests ── + + fn make_entry(locked_at: &str, ttl: u64) -> LockEntry { + LockEntry { + symbol_id: "test::sym".to_string(), + agent_id: "agent-1".to_string(), + intent: "testing".to_string(), + locked_at: locked_at.to_string(), + ttl_seconds: ttl, + } + } + + #[test] + fn test_is_entry_expired_local_fresh() { + let now = chrono::Utc::now().to_rfc3339(); + let entry = make_entry(&now, 600); + assert!(!is_entry_expired_local(&entry)); + } + + #[test] + fn test_is_entry_expired_local_expired() { + let one_hour_ago = (chrono::Utc::now() - chrono::Duration::hours(1)).to_rfc3339(); + let entry = make_entry(&one_hour_ago, 60); + assert!(is_entry_expired_local(&entry)); + } + + #[test] + fn test_is_entry_expired_local_bad_timestamp() { + let entry = make_entry("not-a-timestamp", 600); + assert!(is_entry_expired_local(&entry)); + } +} diff --git a/src/config.rs b/src/config.rs index dbbcd9d..6044619 100644 --- a/src/config.rs +++ b/src/config.rs @@ -50,3 +50,71 @@ impl GritConfig { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::s3_store::S3Config; + use tempfile::TempDir; + + #[test] + fn test_default_config() { + let config = GritConfig::default(); + assert_eq!(config.backend, "local"); + assert!(config.s3.is_none()); + } + + #[test] + fn test_save_and_load() { + let tmp = TempDir::new().unwrap(); + let config = GritConfig { + backend: "local".to_string(), + s3: None, + }; + config.save(tmp.path()).unwrap(); + let loaded = GritConfig::load(tmp.path()).unwrap(); + assert_eq!(loaded.backend, "local"); + assert!(loaded.s3.is_none()); + } + + #[test] + fn test_load_missing_file() { + let tmp = TempDir::new().unwrap(); + // No config.json written — should return default + let config = 
GritConfig::load(tmp.path()).unwrap(); + assert_eq!(config.backend, "local"); + assert!(config.s3.is_none()); + } + + #[test] + fn test_load_malformed_json() { + let tmp = TempDir::new().unwrap(); + let path = tmp.path().join("config.json"); + std::fs::write(&path, "not valid json {{{").unwrap(); + let config = GritConfig::load(tmp.path()).unwrap(); + assert_eq!(config.backend, "local"); + assert!(config.s3.is_none()); + } + + #[test] + fn test_s3_config_roundtrip() { + let tmp = TempDir::new().unwrap(); + let config = GritConfig { + backend: "s3".to_string(), + s3: Some(S3Config { + bucket: "my-bucket".to_string(), + prefix: Some("grit/locks/".to_string()), + region: Some("us-east-1".to_string()), + endpoint: Some("https://custom.endpoint.com".to_string()), + }), + }; + config.save(tmp.path()).unwrap(); + let loaded = GritConfig::load(tmp.path()).unwrap(); + assert_eq!(loaded.backend, "s3"); + let s3 = loaded.s3.unwrap(); + assert_eq!(s3.bucket, "my-bucket"); + assert_eq!(s3.prefix.unwrap(), "grit/locks/"); + assert_eq!(s3.region.unwrap(), "us-east-1"); + assert_eq!(s3.endpoint.unwrap(), "https://custom.endpoint.com"); + } +} diff --git a/src/db/mod.rs b/src/db/mod.rs index 53cc25b..0493d4f 100644 --- a/src/db/mod.rs +++ b/src/db/mod.rs @@ -251,3 +251,173 @@ impl Database { } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::Symbol; + use tempfile::TempDir; + + fn make_symbol(id: &str, file: &str, name: &str, kind: &str) -> Symbol { + Symbol { + id: id.to_string(), + file: file.to_string(), + name: name.to_string(), + kind: kind.to_string(), + start_line: 1, + end_line: 10, + hash: "abc123".to_string(), + } + } + + fn setup_db() -> (TempDir, Database) { + let tmp = TempDir::new().unwrap(); + let db_path = tmp.path().join("test.db"); + let db = Database::open(&db_path).unwrap(); + db.init_schema().unwrap(); + (tmp, db) + } + + #[test] + fn test_open_and_init_schema() { + let tmp = TempDir::new().unwrap(); + let db_path = 
tmp.path().join("test.db"); + let db = Database::open(&db_path).unwrap(); + assert!(db.init_schema().is_ok()); + } + + #[test] + fn test_upsert_and_count_symbols() { + let (_tmp, db) = setup_db(); + let symbols: Vec<Symbol> = (0..5) + .map(|i| make_symbol(&format!("file.rs::fn{}", i), "file.rs", &format!("fn{}", i), "function")) + .collect(); + db.upsert_symbols(&symbols).unwrap(); + assert_eq!(db.count_symbols().unwrap(), 5); + } + + #[test] + fn test_upsert_updates_existing() { + let (_tmp, db) = setup_db(); + let sym = make_symbol("file.rs::foo", "file.rs", "foo", "function"); + db.upsert_symbols(&[sym]).unwrap(); + assert_eq!(db.count_symbols().unwrap(), 1); + + // Update same symbol with different hash + let updated = Symbol { + id: "file.rs::foo".to_string(), + file: "file.rs".to_string(), + name: "foo".to_string(), + kind: "function".to_string(), + start_line: 5, + end_line: 20, + hash: "new_hash".to_string(), + }; + db.upsert_symbols(&[updated]).unwrap(); + assert_eq!(db.count_symbols().unwrap(), 1); + } + + #[test] + fn test_list_symbols_no_filter() { + let (_tmp, db) = setup_db(); + let symbols = vec![ + make_symbol("a.rs::fn1", "a.rs", "fn1", "function"), + make_symbol("a.rs::fn2", "a.rs", "fn2", "function"), + make_symbol("b.rs::fn3", "b.rs", "fn3", "function"), + ]; + db.upsert_symbols(&symbols).unwrap(); + let all = db.list_symbols(None).unwrap(); + assert_eq!(all.len(), 3); + } + + #[test] + fn test_list_symbols_with_filter() { + let (_tmp, db) = setup_db(); + let symbols = vec![ + make_symbol("src/a.rs::fn1", "src/a.rs", "fn1", "function"), + make_symbol("src/a.rs::fn2", "src/a.rs", "fn2", "function"), + make_symbol("src/b.rs::fn3", "src/b.rs", "fn3", "function"), + ]; + db.upsert_symbols(&symbols).unwrap(); + let filtered = db.list_symbols(Some("a.rs")).unwrap(); + assert_eq!(filtered.len(), 2); + for row in &filtered { + assert!(row.1.contains("a.rs")); + } + } + + #[test] + fn test_search_symbols() { + let (_tmp, db) = setup_db(); + let symbols = vec![ + 
make_symbol("src/auth.rs::login", "src/auth.rs", "login", "function"), + make_symbol("src/auth.rs::logout", "src/auth.rs", "logout", "function"), + make_symbol("src/db.rs::connect", "src/db.rs", "connect", "function"), + ]; + db.upsert_symbols(&symbols).unwrap(); + let results = db.search_symbols(&["login"]).unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].2, "login"); + } + + #[test] + fn test_available_symbols_in_files() { + let (_tmp, db) = setup_db(); + let symbols = vec![ + make_symbol("f.rs::a", "f.rs", "a", "function"), + make_symbol("f.rs::b", "f.rs", "b", "function"), + make_symbol("f.rs::c", "f.rs", "c", "function"), + ]; + db.upsert_symbols(&symbols).unwrap(); + + // Lock symbol "f.rs::b" + db.conn.execute( + "INSERT INTO locks (symbol_id, agent_id, intent) VALUES (?1, ?2, ?3)", + params!["f.rs::b", "agent-1", "editing"], + ).unwrap(); + + let available = db.available_symbols_in_files(&["f.rs"]).unwrap(); + assert_eq!(available.len(), 2); + assert!(available.contains(&"f.rs::a".to_string())); + assert!(available.contains(&"f.rs::c".to_string())); + assert!(!available.contains(&"f.rs::b".to_string())); + } + + #[test] + fn test_session_lifecycle() { + let (_tmp, db) = setup_db(); + db.create_session("sess1", "feature/x", "main").unwrap(); + + let active = db.get_active_session().unwrap(); + assert!(active.is_some()); + let (name, branch, base) = active.unwrap(); + assert_eq!(name, "sess1"); + assert_eq!(branch, "feature/x"); + assert_eq!(base, "main"); + + db.close_session("sess1").unwrap(); + let active = db.get_active_session().unwrap(); + assert!(active.is_none()); + } + + #[test] + fn test_no_active_session() { + let (_tmp, db) = setup_db(); + let active = db.get_active_session().unwrap(); + assert!(active.is_none()); + } + + #[test] + fn test_integrity_check_on_open() { + let tmp = TempDir::new().unwrap(); + let db_path = tmp.path().join("test.db"); + // First open creates the file + { + let db = Database::open(&db_path).unwrap(); + 
db.init_schema().unwrap(); + } + // Second open runs integrity check on existing DB + let result = Database::open(&db_path); + assert!(result.is_ok()); + } +} diff --git a/src/db/sqlite_store.rs b/src/db/sqlite_store.rs index 03648ce..129729f 100644 --- a/src/db/sqlite_store.rs +++ b/src/db/sqlite_store.rs @@ -129,3 +129,215 @@ impl LockStore for SqliteLockStore { Ok(count) } } + +#[cfg(test)] +mod tests { + use super::*; + use super::super::lock_store::{LockResult, LockStore}; + use std::sync::Arc; + + /// Create a temporary SQLite database with the locks table and return the store. + fn setup() -> (tempfile::TempDir, SqliteLockStore) { + let dir = tempfile::tempdir().expect("failed to create temp dir"); + let db_path = dir.path().join("test.db"); + + // Create schema directly — avoids needing the full Database struct and symbols table FK + { + let conn = Connection::open(&db_path).unwrap(); + conn.execute_batch( + "PRAGMA journal_mode=WAL; + PRAGMA busy_timeout=5000; + CREATE TABLE IF NOT EXISTS locks ( + symbol_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + intent TEXT, + mode TEXT DEFAULT 'write', + locked_at TEXT DEFAULT (datetime('now')), + ttl_seconds INTEGER DEFAULT 600, + PRIMARY KEY (symbol_id) + ); + CREATE INDEX IF NOT EXISTS idx_locks_agent ON locks(agent_id);", + ) + .unwrap(); + } + + let store = SqliteLockStore::open(&db_path).expect("failed to open store"); + (dir, store) + } + + #[test] + fn test_lock_and_release() { + let (_dir, store) = setup(); + + // Lock a symbol + let result = store.try_lock("sym::foo", "agent-1", "editing foo", 600).unwrap(); + assert!(matches!(result, LockResult::Granted)); + + // Verify it appears in all_locks + let locks = store.all_locks().unwrap(); + assert_eq!(locks.len(), 1); + assert_eq!(locks[0].symbol_id, "sym::foo"); + assert_eq!(locks[0].agent_id, "agent-1"); + + // Release it + store.release("sym::foo", "agent-1").unwrap(); + + // Verify it's gone + let locks = store.all_locks().unwrap(); + 
assert!(locks.is_empty()); + } + + #[test] + fn test_lock_blocked_by_other_agent() { + let (_dir, store) = setup(); + + // Agent A locks + let result = store.try_lock("sym::bar", "agent-A", "refactoring", 600).unwrap(); + assert!(matches!(result, LockResult::Granted)); + + // Agent B tries the same symbol + let result = store.try_lock("sym::bar", "agent-B", "also refactoring", 600).unwrap(); + match result { + LockResult::Blocked { by_agent, by_intent } => { + assert_eq!(by_agent, "agent-A"); + assert_eq!(by_intent, "refactoring"); + } + LockResult::Granted => panic!("expected Blocked, got Granted"), + } + } + + #[test] + fn test_same_agent_relock() { + let (_dir, store) = setup(); + + // Agent A locks + let result = store.try_lock("sym::baz", "agent-A", "first pass", 300).unwrap(); + assert!(matches!(result, LockResult::Granted)); + + // Agent A locks again (should refresh TTL, still Granted) + let result = store.try_lock("sym::baz", "agent-A", "second pass", 900).unwrap(); + assert!(matches!(result, LockResult::Granted)); + + // Verify only one lock exists and TTL was updated + let locks = store.all_locks().unwrap(); + assert_eq!(locks.len(), 1); + assert_eq!(locks[0].ttl_seconds, 900); + assert_eq!(locks[0].intent, "second pass"); + } + + #[test] + fn test_release_all() { + let (_dir, store) = setup(); + + // Agent locks 3 symbols + store.try_lock("sym::a", "agent-1", "intent-a", 600).unwrap(); + store.try_lock("sym::b", "agent-1", "intent-b", 600).unwrap(); + store.try_lock("sym::c", "agent-1", "intent-c", 600).unwrap(); + + // Also one lock by another agent (should not be released) + store.try_lock("sym::d", "agent-2", "intent-d", 600).unwrap(); + + let count = store.release_all("agent-1").unwrap(); + assert_eq!(count, 3); + + let locks = store.all_locks().unwrap(); + assert_eq!(locks.len(), 1); + assert_eq!(locks[0].agent_id, "agent-2"); + } + + #[test] + fn test_all_locks() { + let (_dir, store) = setup(); + + store.try_lock("sym::x", "agent-A", "ix", 
600).unwrap(); + store.try_lock("sym::y", "agent-A", "iy", 600).unwrap(); + store.try_lock("sym::z", "agent-B", "iz", 600).unwrap(); + + let locks = store.all_locks().unwrap(); + assert_eq!(locks.len(), 3); + + // Verify ordering is by agent_id then symbol_id + let ids: Vec<(&str, &str)> = locks.iter().map(|l| (l.agent_id.as_str(), l.symbol_id.as_str())).collect(); + assert_eq!(ids, vec![("agent-A", "sym::x"), ("agent-A", "sym::y"), ("agent-B", "sym::z")]); + } + + #[test] + fn test_locks_for_agent() { + let (_dir, store) = setup(); + + store.try_lock("sym::p", "agent-1", "ip", 600).unwrap(); + store.try_lock("sym::q", "agent-1", "iq", 600).unwrap(); + store.try_lock("sym::r", "agent-2", "ir", 600).unwrap(); + + let agent1_locks = store.locks_for_agent("agent-1").unwrap(); + assert_eq!(agent1_locks.len(), 2); + let symbols: Vec<&str> = agent1_locks.iter().map(|(s, _)| s.as_str()).collect(); + assert!(symbols.contains(&"sym::p")); + assert!(symbols.contains(&"sym::q")); + + let agent2_locks = store.locks_for_agent("agent-2").unwrap(); + assert_eq!(agent2_locks.len(), 1); + assert_eq!(agent2_locks[0].0, "sym::r"); + } + + #[test] + fn test_gc_expired_locks() { + let (_dir, store) = setup(); + + // Lock with TTL=1 second + store.try_lock("sym::expire", "agent-1", "short-lived", 1).unwrap(); + + // Verify it exists + assert_eq!(store.all_locks().unwrap().len(), 1); + + // Sleep to let it expire + std::thread::sleep(std::time::Duration::from_secs(2)); + + // GC should clean it up + let cleaned = store.gc_expired_locks().unwrap(); + assert_eq!(cleaned, 1); + assert!(store.all_locks().unwrap().is_empty()); + } + + #[test] + fn test_refresh_ttl() { + let (_dir, store) = setup(); + + store.try_lock("sym::m", "agent-1", "im", 300).unwrap(); + store.try_lock("sym::n", "agent-1", "in", 300).unwrap(); + + let count = store.refresh_ttl("agent-1", 900).unwrap(); + assert_eq!(count, 2); + + // Verify the TTL was updated + let locks = store.all_locks().unwrap(); + for lock in 
&locks { + assert_eq!(lock.ttl_seconds, 900); + } + } + + #[test] + fn test_concurrent_access() { + let (_dir, store) = setup(); + let store = Arc::new(store); + let mut handles = Vec::new(); + + for i in 0..10 { + let store = Arc::clone(&store); + let handle = std::thread::spawn(move || { + let agent = format!("agent-{}", i); + store.try_lock("sym::contested", &agent, "racing", 600).unwrap() + }); + handles.push(handle); + } + + let results: Vec<LockResult> = handles.into_iter().map(|h| h.join().unwrap()).collect(); + + let granted = results.iter().filter(|r| matches!(r, LockResult::Granted)).count(); + let blocked = results.iter().filter(|r| matches!(r, LockResult::Blocked { .. })).count(); + + // Exactly one thread should win the lock + assert_eq!(granted, 1, "expected exactly 1 Granted, got {}", granted); + assert_eq!(blocked, 9, "expected exactly 9 Blocked, got {}", blocked); + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index d2495f0..26f09db 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -222,3 +222,363 @@ impl SymbolIndex { ] } } + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + /// Helper: create a file inside a TempDir, creating parent dirs as needed. + fn write_file(dir: &TempDir, rel_path: &str, content: &str) -> PathBuf { + let full = dir.path().join(rel_path); + if let Some(parent) = full.parent() { + fs::create_dir_all(parent).unwrap(); + } + fs::write(&full, content).unwrap(); + full + } + + /// Helper: find symbol by name in a slice. + fn find_sym<'a>(symbols: &'a [Symbol], name: &str) -> &'a Symbol { + symbols + .iter() + .find(|s| s.name == name) + .unwrap_or_else(|| panic!("symbol '{}' not found in {:?}", name, symbols.iter().map(|s| &s.name).collect::<Vec<_>>())) + } + + // ── 1. 
Rust functions ────────────────────────────────────────────── + + #[test] + fn test_parse_rust_functions() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/lib.rs", r#" +fn alpha() {} +fn beta(x: i32) -> i32 { x } +fn gamma() -> String { String::new() } +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); + assert_eq!(fns.len(), 3, "expected 3 functions, got {:?}", fns); + + for name in &["alpha", "beta", "gamma"] { + let sym = find_sym(&symbols, name); + assert_eq!(sym.kind, "function"); + } + } + + // ── 2. Rust struct + impl ────────────────────────────────────────── + + #[test] + fn test_parse_rust_struct_and_impl() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/model.rs", r#" +struct Point { + x: f64, + y: f64, +} + +impl Point { + fn distance(&self) -> f64 { + (self.x * self.x + self.y * self.y).sqrt() + } +} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let _struct_sym = find_sym(&symbols, "Point"); + // The struct itself has kind "struct"; the impl also has name "Point" with kind "impl". + let kinds: Vec<_> = symbols.iter().filter(|s| s.name == "Point").map(|s| s.kind.as_str()).collect(); + assert!(kinds.contains(&"struct"), "expected struct, got {:?}", kinds); + assert!(kinds.contains(&"impl"), "expected impl, got {:?}", kinds); + + // The method inside impl should also be extracted + let distance = find_sym(&symbols, "distance"); + assert_eq!(distance.kind, "function"); + } + + // ── 3. 
Rust enum + trait ─────────────────────────────────────────── + + #[test] + fn test_parse_rust_enum_and_trait() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/types.rs", r#" +enum Color { + Red, + Green, + Blue, +} + +trait Drawable { + fn draw(&self); +} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let color = find_sym(&symbols, "Color"); + assert_eq!(color.kind, "enum"); + + let drawable = find_sym(&symbols, "Drawable"); + assert_eq!(drawable.kind, "trait"); + } + + // ── 4. TypeScript functions ──────────────────────────────────────── + + #[test] + fn test_parse_typescript_functions() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/utils.ts", r#" +function add(a: number, b: number): number { + return a + b; +} + +function greet(name: string): string { + return `Hello, ${name}`; +} + +function noop(): void {} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); + assert_eq!(fns.len(), 3); + + for name in &["add", "greet", "noop"] { + let sym = find_sym(&symbols, name); + assert_eq!(sym.kind, "function"); + } + } + + // ── 5. TypeScript class + methods ────────────────────────────────── + + #[test] + fn test_parse_typescript_class() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/service.ts", r#" +class UserService { + getUser(id: number): string { + return "user"; + } + deleteUser(id: number): void {} +} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let cls = find_sym(&symbols, "UserService"); + assert_eq!(cls.kind, "class"); + + let methods: Vec<_> = symbols.iter().filter(|s| s.kind == "method").collect(); + assert_eq!(methods.len(), 2); + find_sym(&symbols, "getUser"); + find_sym(&symbols, "deleteUser"); + } + + // ── 6. 
TypeScript interface ──────────────────────────────────────── + + #[test] + fn test_parse_typescript_interface() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/types.ts", r#" +interface Config { + host: string; + port: number; +} + +interface Logger { + log(msg: string): void; +} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let interfaces: Vec<_> = symbols.iter().filter(|s| s.kind == "interface").collect(); + assert_eq!(interfaces.len(), 2); + find_sym(&symbols, "Config"); + find_sym(&symbols, "Logger"); + } + + // ── 7. Python functions ──────────────────────────────────────────── + + #[test] + fn test_parse_python_functions() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "utils.py", r#" +def connect(host, port): + pass + +def disconnect(): + pass + +def retry(fn, times=3): + pass +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); + assert_eq!(fns.len(), 3); + for name in &["connect", "disconnect", "retry"] { + find_sym(&symbols, name); + } + } + + // ── 8. 
Python class + methods ────────────────────────────────────── + + #[test] + fn test_parse_python_class() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "models.py", r#" +class Dog: + def __init__(self, name): + self.name = name + + def bark(self): + return "woof" +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let cls = find_sym(&symbols, "Dog"); + assert_eq!(cls.kind, "class"); + + // Python methods are function_definition nodes inside a class — they get kind "function" + let methods: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); + assert!(methods.len() >= 2, "expected at least __init__ and bark"); + find_sym(&symbols, "__init__"); + find_sym(&symbols, "bark"); + } + + // ── 9. JavaScript functions ──────────────────────────────────────── + + #[test] + fn test_parse_javascript_functions() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "lib/helpers.js", r#" +function sum(a, b) { + return a + b; +} + +function multiply(a, b) { + return a * b; +} +"#); + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + let fns: Vec<_> = symbols.iter().filter(|s| s.kind == "function").collect(); + assert_eq!(fns.len(), 2); + find_sym(&symbols, "sum"); + find_sym(&symbols, "multiply"); + } + + // ── 10. Symbol ID format ─────────────────────────────────────────── + + #[test] + fn test_symbol_id_format() { + let dir = TempDir::new().unwrap(); + write_file(&dir, "src/core/engine.rs", "fn run() {}\n"); + + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + assert_eq!(symbols.len(), 1); + + let sym = &symbols[0]; + assert_eq!(sym.id, "src/core/engine.rs::run"); + assert_eq!(sym.file, "src/core/engine.rs"); + assert_eq!(sym.name, "run"); + } + + // ── 11. 
Hash determinism ─────────────────────────────────────────── + + #[test] + fn test_symbol_hash_deterministic() { + let dir1 = TempDir::new().unwrap(); + let dir2 = TempDir::new().unwrap(); + let code = "fn deterministic() { let x = 42; }\n"; + write_file(&dir1, "a.rs", code); + write_file(&dir2, "a.rs", code); + + let s1 = SymbolIndex::new(dir1.path().to_str().unwrap()).unwrap().scan_all().unwrap(); + let s2 = SymbolIndex::new(dir2.path().to_str().unwrap()).unwrap().scan_all().unwrap(); + + assert_eq!(s1.len(), 1); + assert_eq!(s2.len(), 1); + assert_eq!(s1[0].hash, s2[0].hash, "same source text must produce the same hash"); + assert!(!s1[0].hash.is_empty()); + } + + // ── 12. Skips node_modules ───────────────────────────────────────── + + #[test] + fn test_skips_node_modules() { + let dir = TempDir::new().unwrap(); + // File inside a nested node_modules — should be skipped + // Note: the skip filter checks for "/node_modules/" in the relative path, + // so node_modules must be inside a parent directory (e.g., src/node_modules/). + write_file(&dir, "src/node_modules/lodash/index.js", "function chunk() {}\n"); + // File outside node_modules — should be found + write_file(&dir, "src/app.js", "function main() {}\n"); + + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + assert_eq!(symbols.len(), 1, "only src/app.js should be scanned"); + assert_eq!(symbols[0].name, "main"); + } + + // ── 13. 
Normalize kind (via public interface) ────────────────────── + + #[test] + fn test_normalize_kind() { + let dir = TempDir::new().unwrap(); + + // Rust: function_item → "function", struct_item → "struct", enum_item → "enum", + // trait_item → "trait", impl_item → "impl" + write_file(&dir, "src/all.rs", r#" +fn my_func() {} +struct MyStruct { x: i32 } +enum MyEnum { A, B } +trait MyTrait { fn do_it(&self); } +impl MyStruct { fn method(&self) {} } +"#); + + // TypeScript: function_declaration → "function", class_declaration → "class", + // method_definition → "method", interface_declaration → "interface" + write_file(&dir, "src/all.ts", r#" +function tsFunc(): void {} +class TsClass { + tsMethod(): void {} +} +interface TsInterface { x: number; } +"#); + + // Python: function_definition → "function", class_definition → "class" + write_file(&dir, "all.py", "def py_func():\n pass\n\nclass PyClass:\n pass\n"); + + let idx = SymbolIndex::new(dir.path().to_str().unwrap()).unwrap(); + let symbols = idx.scan_all().unwrap(); + + // Verify the normalized kinds + assert_eq!(find_sym(&symbols, "my_func").kind, "function"); + assert_eq!(find_sym(&symbols, "MyStruct").kind, "struct"); + assert_eq!(find_sym(&symbols, "MyEnum").kind, "enum"); + assert_eq!(find_sym(&symbols, "MyTrait").kind, "trait"); + // impl MyStruct → name="MyStruct", kind="impl" (second entry with that name) + let impls: Vec<_> = symbols.iter().filter(|s| s.kind == "impl").collect(); + assert!(!impls.is_empty(), "expected at least one impl symbol"); + + assert_eq!(find_sym(&symbols, "tsFunc").kind, "function"); + assert_eq!(find_sym(&symbols, "TsClass").kind, "class"); + assert_eq!(find_sym(&symbols, "tsMethod").kind, "method"); + assert_eq!(find_sym(&symbols, "TsInterface").kind, "interface"); + + assert_eq!(find_sym(&symbols, "py_func").kind, "function"); + assert_eq!(find_sym(&symbols, "PyClass").kind, "class"); + } +}