Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ colored = "2"
aws-config = { version = "1", features = ["behavior-version-latest"] }
aws-sdk-s3 = "1"
urlencoding = "2"

[dev-dependencies]
tempfile = "3"
80 changes: 80 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,3 +933,83 @@ fn cmd_config_show(repo: &str) -> Result<()> {

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::db::lock_store::LockEntry;

// ── validate_identifier tests ──

#[test]
fn test_validate_identifier_valid() {
assert!(validate_identifier("agent-1", "id").is_ok());
assert!(validate_identifier("my_agent", "id").is_ok());
assert!(validate_identifier("agent.v2", "id").is_ok());
assert!(validate_identifier("abc123", "id").is_ok());
}

#[test]
fn test_validate_identifier_empty() {
assert!(validate_identifier("", "id").is_err());
}

#[test]
fn test_validate_identifier_path_traversal() {
assert!(validate_identifier("..", "id").is_err());
}

#[test]
fn test_validate_identifier_slash() {
assert!(validate_identifier("foo/bar", "id").is_err());
}

#[test]
fn test_validate_identifier_backslash() {
assert!(validate_identifier("foo\\bar", "id").is_err());
}

#[test]
fn test_validate_identifier_starts_with_dash() {
assert!(validate_identifier("-agent", "id").is_err());
}

#[test]
fn test_validate_identifier_special_chars() {
assert!(validate_identifier("foo@bar", "id").is_err());
assert!(validate_identifier("foo bar", "id").is_err());
assert!(validate_identifier("foo;rm", "id").is_err());
}

// ── is_entry_expired_local tests ──

fn make_entry(locked_at: &str, ttl: u64) -> LockEntry {
LockEntry {
symbol_id: "test::sym".to_string(),
agent_id: "agent-1".to_string(),
intent: "testing".to_string(),
locked_at: locked_at.to_string(),
ttl_seconds: ttl,
}
}

#[test]
fn test_is_entry_expired_local_fresh() {
let now = chrono::Utc::now().to_rfc3339();
let entry = make_entry(&now, 600);
assert!(!is_entry_expired_local(&entry));
}

#[test]
fn test_is_entry_expired_local_expired() {
let one_hour_ago = (chrono::Utc::now() - chrono::Duration::hours(1)).to_rfc3339();
let entry = make_entry(&one_hour_ago, 60);
assert!(is_entry_expired_local(&entry));
}

#[test]
fn test_is_entry_expired_local_bad_timestamp() {
let entry = make_entry("not-a-timestamp", 600);
assert!(is_entry_expired_local(&entry));
}
}
68 changes: 68 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,71 @@ impl GritConfig {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::db::s3_store::S3Config;
use tempfile::TempDir;

#[test]
fn test_default_config() {
let config = GritConfig::default();
assert_eq!(config.backend, "local");
assert!(config.s3.is_none());
}

#[test]
fn test_save_and_load() {
let tmp = TempDir::new().unwrap();
let config = GritConfig {
backend: "local".to_string(),
s3: None,
};
config.save(tmp.path()).unwrap();
let loaded = GritConfig::load(tmp.path()).unwrap();
assert_eq!(loaded.backend, "local");
assert!(loaded.s3.is_none());
}

#[test]
fn test_load_missing_file() {
let tmp = TempDir::new().unwrap();
// No config.json written — should return default
let config = GritConfig::load(tmp.path()).unwrap();
assert_eq!(config.backend, "local");
assert!(config.s3.is_none());
}

#[test]
fn test_load_malformed_json() {
let tmp = TempDir::new().unwrap();
let path = tmp.path().join("config.json");
std::fs::write(&path, "not valid json {{{").unwrap();
let config = GritConfig::load(tmp.path()).unwrap();
assert_eq!(config.backend, "local");
assert!(config.s3.is_none());
}

#[test]
fn test_s3_config_roundtrip() {
let tmp = TempDir::new().unwrap();
let config = GritConfig {
backend: "s3".to_string(),
s3: Some(S3Config {
bucket: "my-bucket".to_string(),
prefix: Some("grit/locks/".to_string()),
region: Some("us-east-1".to_string()),
endpoint: Some("https://custom.endpoint.com".to_string()),
}),
};
config.save(tmp.path()).unwrap();
let loaded = GritConfig::load(tmp.path()).unwrap();
assert_eq!(loaded.backend, "s3");
let s3 = loaded.s3.unwrap();
assert_eq!(s3.bucket, "my-bucket");
assert_eq!(s3.prefix.unwrap(), "grit/locks/");
assert_eq!(s3.region.unwrap(), "us-east-1");
assert_eq!(s3.endpoint.unwrap(), "https://custom.endpoint.com");
}
}
170 changes: 170 additions & 0 deletions src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,173 @@ impl Database {
}

}

#[cfg(test)]
mod tests {
use super::*;
use crate::parser::Symbol;
use tempfile::TempDir;

fn make_symbol(id: &str, file: &str, name: &str, kind: &str) -> Symbol {
Symbol {
id: id.to_string(),
file: file.to_string(),
name: name.to_string(),
kind: kind.to_string(),
start_line: 1,
end_line: 10,
hash: "abc123".to_string(),
}
}

fn setup_db() -> (TempDir, Database) {
let tmp = TempDir::new().unwrap();
let db_path = tmp.path().join("test.db");
let db = Database::open(&db_path).unwrap();
db.init_schema().unwrap();
(tmp, db)
}

#[test]
fn test_open_and_init_schema() {
let tmp = TempDir::new().unwrap();
let db_path = tmp.path().join("test.db");
let db = Database::open(&db_path).unwrap();
assert!(db.init_schema().is_ok());
}

#[test]
fn test_upsert_and_count_symbols() {
let (_tmp, db) = setup_db();
let symbols: Vec<Symbol> = (0..5)
.map(|i| make_symbol(&format!("file.rs::fn{}", i), "file.rs", &format!("fn{}", i), "function"))
.collect();
db.upsert_symbols(&symbols).unwrap();
assert_eq!(db.count_symbols().unwrap(), 5);
}

#[test]
fn test_upsert_updates_existing() {
let (_tmp, db) = setup_db();
let sym = make_symbol("file.rs::foo", "file.rs", "foo", "function");
db.upsert_symbols(&[sym]).unwrap();
assert_eq!(db.count_symbols().unwrap(), 1);

// Update same symbol with different hash
let updated = Symbol {
id: "file.rs::foo".to_string(),
file: "file.rs".to_string(),
name: "foo".to_string(),
kind: "function".to_string(),
start_line: 5,
end_line: 20,
hash: "new_hash".to_string(),
};
db.upsert_symbols(&[updated]).unwrap();
assert_eq!(db.count_symbols().unwrap(), 1);
}

#[test]
fn test_list_symbols_no_filter() {
let (_tmp, db) = setup_db();
let symbols = vec![
make_symbol("a.rs::fn1", "a.rs", "fn1", "function"),
make_symbol("a.rs::fn2", "a.rs", "fn2", "function"),
make_symbol("b.rs::fn3", "b.rs", "fn3", "function"),
];
db.upsert_symbols(&symbols).unwrap();
let all = db.list_symbols(None).unwrap();
assert_eq!(all.len(), 3);
}

#[test]
fn test_list_symbols_with_filter() {
let (_tmp, db) = setup_db();
let symbols = vec![
make_symbol("src/a.rs::fn1", "src/a.rs", "fn1", "function"),
make_symbol("src/a.rs::fn2", "src/a.rs", "fn2", "function"),
make_symbol("src/b.rs::fn3", "src/b.rs", "fn3", "function"),
];
db.upsert_symbols(&symbols).unwrap();
let filtered = db.list_symbols(Some("a.rs")).unwrap();
assert_eq!(filtered.len(), 2);
for row in &filtered {
assert!(row.1.contains("a.rs"));
}
}

#[test]
fn test_search_symbols() {
let (_tmp, db) = setup_db();
let symbols = vec![
make_symbol("src/auth.rs::login", "src/auth.rs", "login", "function"),
make_symbol("src/auth.rs::logout", "src/auth.rs", "logout", "function"),
make_symbol("src/db.rs::connect", "src/db.rs", "connect", "function"),
];
db.upsert_symbols(&symbols).unwrap();
let results = db.search_symbols(&["login"]).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].2, "login");
}

#[test]
fn test_available_symbols_in_files() {
let (_tmp, db) = setup_db();
let symbols = vec![
make_symbol("f.rs::a", "f.rs", "a", "function"),
make_symbol("f.rs::b", "f.rs", "b", "function"),
make_symbol("f.rs::c", "f.rs", "c", "function"),
];
db.upsert_symbols(&symbols).unwrap();

// Lock symbol "f.rs::b"
db.conn.execute(
"INSERT INTO locks (symbol_id, agent_id, intent) VALUES (?1, ?2, ?3)",
params!["f.rs::b", "agent-1", "editing"],
).unwrap();

let available = db.available_symbols_in_files(&["f.rs"]).unwrap();
assert_eq!(available.len(), 2);
assert!(available.contains(&"f.rs::a".to_string()));
assert!(available.contains(&"f.rs::c".to_string()));
assert!(!available.contains(&"f.rs::b".to_string()));
}

#[test]
fn test_session_lifecycle() {
let (_tmp, db) = setup_db();
db.create_session("sess1", "feature/x", "main").unwrap();

let active = db.get_active_session().unwrap();
assert!(active.is_some());
let (name, branch, base) = active.unwrap();
assert_eq!(name, "sess1");
assert_eq!(branch, "feature/x");
assert_eq!(base, "main");

db.close_session("sess1").unwrap();
let active = db.get_active_session().unwrap();
assert!(active.is_none());
}

#[test]
fn test_no_active_session() {
let (_tmp, db) = setup_db();
let active = db.get_active_session().unwrap();
assert!(active.is_none());
}

#[test]
fn test_integrity_check_on_open() {
let tmp = TempDir::new().unwrap();
let db_path = tmp.path().join("test.db");
// First open creates the file
{
let db = Database::open(&db_path).unwrap();
db.init_schema().unwrap();
}
// Second open runs integrity check on existing DB
let result = Database::open(&db_path);
assert!(result.is_ok());
}
}
Loading
Loading