Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 259 additions & 12 deletions src/hook_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,30 +152,170 @@ pub fn run_gemini() -> Result<()> {

let tool_name = json.get("tool_name").and_then(|v| v.as_str()).unwrap_or("");

if tool_name != "run_shell_command" {
match tool_name {
"run_shell_command" => {
let cmd = json
.pointer("/tool_input/command")
.and_then(|v| v.as_str())
.unwrap_or("");

if cmd.is_empty() {
print_allow();
return Ok(());
}

match rewrite_command(cmd, &[]) {
Some(rewritten) => print_rewrite(&rewritten),
None => print_allow(),
}
}
"read_file" => {
let file_path = json
.pointer("/tool_input/file_path")
.and_then(|v| v.as_str())
.unwrap_or("");

if file_path.is_empty() {
print_allow();
return Ok(());
}

// Extract optional line range so we can pre-apply it before filtering,
// then override both fields in the response to prevent Gemini's merge
// from re-applying the original (now wrong) line numbers to the filtered file.
let start_line = json
.pointer("/tool_input/start_line")
.and_then(|v| v.as_u64())
.map(|n| n as usize);
let end_line = json
.pointer("/tool_input/end_line")
.and_then(|v| v.as_u64())
.map(|n| n as usize);

handle_gemini_read_file(file_path, start_line, end_line)?;
}
_ => print_allow(),
}

Ok(())
}

/// Intercept Gemini CLI's `read_file` tool, filter the content, write to a temp file,
/// and redirect Gemini to read the filtered version instead.
///
/// `start_line`/`end_line` (1-based, inclusive) are pre-applied before filtering so the
/// filter sees only the relevant slice. Both are then overridden in the hook response to
/// prevent Gemini's merge from re-applying the original (now wrong) numbers to the temp file.
fn handle_gemini_read_file(
file_path: &str,
start_line: Option<usize>,
end_line: Option<usize>,
) -> Result<()> {
use crate::filter::{self, FilterLevel, Language};
use std::fs;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};

let path = Path::new(file_path);

// Read file — pass through silently for unreadable/binary files
let content = match fs::read_to_string(path) {
Ok(c) => c,
Err(_) => {
print_allow();
return Ok(());
}
};

// Pre-apply the requested line range (1-based, inclusive) before filtering,
// so the filter sees only the relevant slice and line numbers stay meaningful.
let slice: String = match (start_line, end_line) {
(Some(start), Some(end)) => content
.lines()
.skip(start.saturating_sub(1))
.take(end.saturating_sub(start.saturating_sub(1)))
.collect::<Vec<_>>()
.join("\n"),
(Some(start), None) => content
.lines()
.skip(start.saturating_sub(1))
.collect::<Vec<_>>()
.join("\n"),
(None, Some(end)) => content.lines().take(end).collect::<Vec<_>>().join("\n"),
(None, None) => content,
};

// Not worth filtering very small slices
if slice.len() < 500 {
print_allow();
return Ok(());
}

let cmd = json
.pointer("/tool_input/command")
.and_then(|v| v.as_str())
.unwrap_or("");
// Detect language from extension and apply minimal filter
let lang = path
.extension()
.and_then(|e| e.to_str())
.map(Language::from_extension)
.unwrap_or(Language::Unknown);

let filtered = filter::get_filter(FilterLevel::Minimal).filter(&slice, &lang);

if cmd.is_empty() {
// Skip if savings are negligible (<10% by character count)
let savings_pct = 100.0 - (filtered.len() as f64 / slice.len() as f64 * 100.0);
if savings_pct < 10.0 {
print_allow();
return Ok(());
}

// Delegate to the single source of truth for command rewriting
match rewrite_command(cmd, &[]) {
Some(rewritten) => print_rewrite(&rewritten),
None => print_allow(),
}
// Write filtered content to ~/.local/share/rtk/filtered/
let filtered_dir = dirs::data_local_dir()
.unwrap_or_else(|| std::path::PathBuf::from("/tmp"))
.join("rtk")
.join("filtered");

fs::create_dir_all(&filtered_dir).context("Failed to create rtk filtered dir")?;

// Rotate: keep only the last 20 filtered files
cleanup_filtered_files(&filtered_dir, 20);

let ts = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_millis();
let stem = path.file_name().and_then(|n| n.to_str()).unwrap_or("file");
// Timestamp-first so lexicographic sort in cleanup_filtered_files is chronological.
let temp_path = filtered_dir.join(format!("{ts}-{stem}.rtk"));

fs::write(&temp_path, &filtered).context("Failed to write filtered file")?;

let filtered_line_count = filtered.lines().count();
print_rewrite_read_file(&temp_path.to_string_lossy(), filtered_line_count);
Ok(())
}

/// Rotate filtered files: keep only the newest `max_files` `.rtk` files, delete the rest.
///
/// Best-effort housekeeping: read or delete failures are silently ignored so that
/// rotation can never fail the hook itself. Non-`.rtk` files are left untouched.
fn cleanup_filtered_files(dir: &std::path::Path, max_files: usize) {
    let mut entries: Vec<_> = std::fs::read_dir(dir)
        .ok()
        .into_iter()
        .flatten()
        .filter_map(|e| e.ok())
        .filter(|e| e.path().extension().is_some_and(|ext| ext == "rtk"))
        .collect();

    if entries.len() <= max_files {
        return;
    }

    // Filenames start with a millisecond timestamp (see the writer in
    // handle_gemini_read_file), so sorting by file name is chronological.
    entries.sort_by_key(|e| e.file_name());

    // Delete the oldest entries until only `max_files` remain.
    let to_remove = entries.len() - max_files;
    for entry in entries.iter().take(to_remove) {
        let _ = std::fs::remove_file(entry.path());
    }
}

/// Emit the hook response that lets the tool call proceed unchanged.
fn print_allow() {
    let response = r#"{"decision":"allow"}"#;
    println!("{response}");
}
Expand All @@ -189,7 +329,23 @@ fn print_rewrite(cmd: &str) {
}
}
});
println!("{}", output);
println!("{output}");
}

fn print_rewrite_read_file(filtered_path: &str, line_count: usize) {
let output = serde_json::json!({
"decision": "allow",
"hookSpecificOutput": {
"tool_input": {
"file_path": filtered_path,
// Override start/end so Gemini's merge doesn't re-apply the original
// (now wrong) line numbers to the filtered temp file.
"start_line": 1,
"end_line": line_count
}
}
});
println!("{output}");
}

#[cfg(test)]
Expand Down Expand Up @@ -330,4 +486,95 @@ mod tests {
Some("RUST_LOG=debug rtk cargo test".into())
);
}

// --- read_file hook ---

#[test]
fn test_print_rewrite_read_file_format() {
    // The read_file rewrite payload must carry the temp-file path plus a
    // start_line/end_line override neutralising the original line range.
    let output = serde_json::json!({
        "decision": "allow",
        "hookSpecificOutput": {
            "tool_input": {
                "file_path": "/tmp/filtered.rtk",
                "start_line": 1,
                "end_line": 42
            }
        }
    });
    let parsed: Value = serde_json::from_str(&output.to_string()).unwrap();
    assert_eq!(parsed["decision"], "allow");
    let tool_input = &parsed["hookSpecificOutput"]["tool_input"];
    assert_eq!(tool_input["file_path"], "/tmp/filtered.rtk");
    assert_eq!(tool_input["start_line"], 1);
    assert_eq!(tool_input["end_line"], 42);
}

#[test]
fn test_handle_gemini_read_file_missing_path_allows() {
    // An empty file_path in the hook input must extract as "" (the allow path).
    let input = json!({
        "tool_name": "read_file",
        "tool_input": { "file_path": "" }
    });
    let extracted = input
        .pointer("/tool_input/file_path")
        .and_then(|v| v.as_str());
    assert_eq!(extracted, Some(""));
}

#[test]
fn test_handle_gemini_read_file_nonexistent_allows() {
    // A path that does not exist must fall back to allow rather than erroring.
    let missing = "/tmp/rtk-nonexistent-file-xyz-12345.txt";
    assert!(handle_gemini_read_file(missing, None, None).is_ok());
}

#[test]
fn test_handle_gemini_read_file_small_file_allows() {
    use std::io::Write;
    use tempfile::NamedTempFile;

    // A file far below the 500-char threshold is allowed through unfiltered.
    let mut f = NamedTempFile::with_suffix(".rs").expect("temp file");
    writeln!(f, "fn main() {{}}").expect("write");

    let path = f.path().to_str().unwrap();
    assert!(handle_gemini_read_file(path, None, None).is_ok());
}

#[test]
fn test_cleanup_filtered_files_under_limit() {
    use std::fs;
    use tempfile::TempDir;

    // With fewer .rtk files (5) than the limit (20), rotation deletes nothing.
    let dir = TempDir::new().expect("temp dir");
    for i in 0..5 {
        fs::write(dir.path().join(format!("file-{i}.rtk")), "x").expect("write");
    }

    cleanup_filtered_files(dir.path(), 20);

    let remaining = fs::read_dir(dir.path()).unwrap().count();
    assert_eq!(remaining, 5);
}

#[test]
fn test_cleanup_filtered_files_over_limit() {
    use std::fs;
    use tempfile::TempDir;

    let dir = TempDir::new().expect("temp dir");
    // Create 25 .rtk files — the 5 lexicographically smallest (oldest) should be deleted.
    for i in 0..25u64 {
        fs::write(dir.path().join(format!("file-{i:05}.rtk")), "x").expect("write");
    }
    cleanup_filtered_files(dir.path(), 20);

    let remaining = fs::read_dir(dir.path()).unwrap().count();
    assert_eq!(remaining, 20);
    // Pin WHICH files survive, not just how many: rotation must remove the
    // oldest (smallest names) and keep the newest twenty.
    for i in 0..5u64 {
        assert!(!dir.path().join(format!("file-{i:05}.rtk")).exists());
    }
    for i in 5..25u64 {
        assert!(dir.path().join(format!("file-{i:05}.rtk")).exists());
    }
}
}