diff --git a/agent/src/bench_tools.rs b/agent/src/bench_tools.rs index 543e99e..3d5d306 100644 --- a/agent/src/bench_tools.rs +++ b/agent/src/bench_tools.rs @@ -1,4 +1,4 @@ -// Benchmark, profiling, and correctness test tools for the Agent framework. +// Benchmark, profiling, and correctness test tools for the Agent framework. use crate::tools::{BenchmarkComparison, BenchmarkResult, FunctionProfile, ToolResult}; use serde::Deserialize; @@ -22,12 +22,34 @@ const BENCHMARK_TIMEOUT_SECS: u64 = 600; /// Default recall threshold for correctness tests. const RECALL_THRESHOLD: f64 = 0.95; + +fn anti_cheat_default_passed() -> bool { + true +} + +#[derive(Debug, Clone, Deserialize)] +struct AntiCheatOutput { + #[serde(default = "anti_cheat_default_passed")] + passed: bool, + #[serde(default)] + message: String, +} + +impl Default for AntiCheatOutput { + fn default() -> Self { + Self { + passed: true, + message: String::new(), + } + } +} + /// Wrapper for the benchmark binary's JSON output: `{"benchmark": ..., "anti_cheat": ...}`. #[derive(Debug, Deserialize)] struct BenchmarkOutput { benchmark: BenchmarkResult, - #[allow(dead_code)] - anti_cheat: serde_json::Value, + #[serde(default)] + anti_cheat: AntiCheatOutput, } /// Timeout (seconds) for the server to become ready after startup. @@ -51,7 +73,8 @@ fn next_round_number(dir: &Path, prefix: &str) -> u32 { let name = entry.file_name().to_string_lossy().to_string(); // Parse "prefix_NNN.ext" �?NNN if name.starts_with(prefix) { - if let Some(num_part) = name.strip_prefix(prefix).and_then(|s| s.split('.').next()) { + if let Some(num_part) = name.strip_prefix(prefix).and_then(|s| s.split('.').next()) + { if let Ok(n) = num_part.parse::() { max = max.max(n); } @@ -109,6 +132,14 @@ fn build_comparison(prev: &BenchmarkResult, curr: &BenchmarkResult) -> Benchmark } } +fn apply_anti_cheat_guard(benchmark: &mut BenchmarkResult, anti_cheat: &AntiCheatOutput) { + if anti_cheat.passed { + return; + } + benchmark.qps = 0.0; + benchmark.recall_passed = false; +} + /// Save profiling results (flamegraph + report) to profiling/ with round numbers. /// Returns (round_number, flamegraph_path, report_path). fn save_profiling_results( @@ -133,7 +164,10 @@ fn save_profiling_results( let report_path = dir.join(format!("report_{:03}.txt", round)); let mut report = String::new(); report.push_str(&format!("Profiling Report #{:03}\n", round)); - report.push_str(&format!("Date: {}\n\n", chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC"))); + report.push_str(&format!( + "Date: {}\n\n", + chrono::Utc::now().format("%Y-%m-%d %H:%M:%S UTC") + )); report.push_str("Top Functions:\n"); for f in top_functions { report.push_str(&format!(" {:>6.2}% {}\n", f.percentage, f.function)); @@ -245,11 +279,7 @@ async fn try_cargo_build_inner(work_dir: &Path, profiling: bool) -> Result<(), S cmd.env("CARGO_PROFILE_RELEASE_CODEGEN_UNITS", "16"); } - let result = timeout( - Duration::from_secs(BUILD_TIMEOUT_SECS), - cmd.output(), - ) - .await; + let result = timeout(Duration::from_secs(BUILD_TIMEOUT_SECS), cmd.output()).await; match result { Ok(Ok(output)) => { @@ -290,7 +320,11 @@ pub async fn build_project_tool(work_dir: &Path) -> ToolResult { /// /// Launches `target/release/` in `work_dir` with the `PORT` environment /// variable set to the given port. -async fn start_server(work_dir: &Path, port: u16, cpu_cores: Option<&str>) -> Result { +async fn start_server( + work_dir: &Path, + port: u16, + cpu_cores: Option<&str>, +) -> Result { // Dynamically detect binary name from Cargo.toml let binary_name = detect_binary_name(work_dir)?; let binary = work_dir.join(format!("target/release/{}", binary_name)); @@ -310,13 +344,9 @@ async fn start_server(work_dir: &Path, port: u16, cpu_cores: Option<&str>) -> Re ) })?; - let abs_work_dir = work_dir.canonicalize().map_err(|e| { - format!( - "Failed to resolve work_dir '{}': {}", - work_dir.display(), - e - ) - })?; + let abs_work_dir = work_dir + .canonicalize() + .map_err(|e| format!("Failed to resolve work_dir '{}': {}", work_dir.display(), e))?; // Use taskset to pin the server to specific CPU cores. // This prevents the model's code from consuming all CPUs on the machine. @@ -467,7 +497,6 @@ fn find_base_vectors(data_dir: &Path, work_dir: &Path) -> Result(&stdout) { Ok(output) => { let mut bench = output.benchmark; + if !output.anti_cheat.passed { + eprintln!( + "[benchmark] Anti-cheat failed. Invalidating score (QPS=0). Detail: {}", + if output.anti_cheat.message.is_empty() { + "SUSPICIOUS benchmark output" + } else { + output.anti_cheat.message.as_str() + } + ); + } + apply_anti_cheat_guard(&mut bench, &output.anti_cheat); // Add comparison with previous run if let Some(prev) = load_previous_benchmark(work_dir) { bench.comparison = Some(build_comparison(&prev, &bench)); @@ -635,7 +677,6 @@ pub async fn run_benchmark( } } - /// Run performance profiling on the skeleton server process. /// /// **Unlike `run_benchmark` and `run_correctness_test`, this function does NOT @@ -646,7 +687,11 @@ pub async fn run_benchmark( /// Manages the full server lifecycle: kill leftover processes on the port, /// build the project, start the server, wait for readiness, run `perf record`, /// generate flamegraph, extract top functions, and finally kill the server. -pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Option) -> ToolResult { +pub async fn run_profiling( + work_dir: &Path, + config: &BenchConfig, + _duration: Option, +) -> ToolResult { let perf_data = work_dir.join("perf.data"); let flamegraph_svg = work_dir.join("flamegraph.svg"); @@ -675,7 +720,9 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt }; // 4. Wait for server readiness - if let Err(e) = wait_for_server_ready(port, SERVER_READY_TIMEOUT_SECS, SERVER_POLL_INTERVAL_MS).await { + if let Err(e) = + wait_for_server_ready(port, SERVER_READY_TIMEOUT_SECS, SERVER_POLL_INTERVAL_MS).await + { kill_server(&mut child).await; return ToolResult::Error { message: format!("Server not ready: {}", e), @@ -698,7 +745,9 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt Ok(p) => p, Err(e) => { kill_server(&mut child).await; - return ToolResult::Error { message: format!("Profiling needs real data: {}", e) }; + return ToolResult::Error { + message: format!("Profiling needs real data: {}", e), + }; } }; let query_vectors = config.data_dir.join("query_vectors.json"); @@ -730,10 +779,13 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt let mut perf_child = match Command::new("perf") .args([ "record", - "-F", "99", - "-p", &pid.to_string(), + "-F", + "99", + "-p", + &pid.to_string(), "-g", - "-o", perf_data.to_str().unwrap_or("perf.data"), + "-o", + perf_data.to_str().unwrap_or("perf.data"), ]) .current_dir(work_dir) .stdout(std::process::Stdio::piped()) @@ -753,13 +805,20 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt let bench_result = timeout( Duration::from_secs(BENCHMARK_TIMEOUT_SECS), Command::new(benchmark_bin.to_str().unwrap_or_default()) - .arg("--server-url").arg(format!("http://127.0.0.1:{}", port)) - .arg("--concurrency").arg("4") - .arg("--warmup").arg("100") - .arg("--max-queries").arg("1000") - .arg("--base-vectors").arg(base_vectors.to_str().unwrap_or_default()) - .arg("--query-vectors").arg(query_vectors.to_str().unwrap_or_default()) - .arg("--ground-truth").arg(ground_truth.to_str().unwrap_or_default()) + .arg("--server-url") + .arg(format!("http://127.0.0.1:{}", port)) + .arg("--concurrency") + .arg("4") + .arg("--warmup") + .arg("100") + .arg("--max-queries") + .arg("1000") + .arg("--base-vectors") + .arg(base_vectors.to_str().unwrap_or_default()) + .arg("--query-vectors") + .arg(query_vectors.to_str().unwrap_or_default()) + .arg("--ground-truth") + .arg(ground_truth.to_str().unwrap_or_default()) .current_dir(work_dir) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::piped()) @@ -772,14 +831,18 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt { if let Some(perf_pid) = perf_child.id() { // SIGINT (2) tells perf to flush and exit cleanly - unsafe { libc::kill(perf_pid as i32, libc::SIGINT); } + unsafe { + libc::kill(perf_pid as i32, libc::SIGINT); + } } } // Wait for perf to finish writing let perf_wait = timeout(Duration::from_secs(10), perf_child.wait()).await; match perf_wait { Ok(Ok(_)) => {} // perf exited - _ => { let _ = perf_child.kill().await; } // force kill if stuck + _ => { + let _ = perf_child.kill().await; + } // force kill if stuck } // Log benchmark result (informational, not the main output) @@ -789,13 +852,17 @@ pub async fn run_profiling(work_dir: &Path, config: &BenchConfig, _duration: Opt eprintln!("[profiling] Benchmark client completed successfully during profiling."); } else { let stderr = String::from_utf8_lossy(&output.stderr); - eprintln!("[profiling] Benchmark client exited with code {}: {}", + eprintln!( + "[profiling] Benchmark client exited with code {}: {}", output.status.code().unwrap_or(-1), - &stderr[..stderr.len().min(500)]); + &stderr[..stderr.len().min(500)] + ); } } Ok(Err(e)) => eprintln!("[profiling] Benchmark client failed to execute: {}", e), - Err(_) => eprintln!("[profiling] Benchmark client timed out (profiling data should still be valid)."), + Err(_) => eprintln!( + "[profiling] Benchmark client timed out (profiling data should still be valid)." + ), } // Check perf.data was produced @@ -918,12 +985,15 @@ fn parse_perf_report(report: &str) -> Vec { } // Sort by percentage descending, take top 10 - functions.sort_by(|a, b| b.percentage.partial_cmp(&a.percentage).unwrap_or(std::cmp::Ordering::Equal)); + functions.sort_by(|a, b| { + b.percentage + .partial_cmp(&a.percentage) + .unwrap_or(std::cmp::Ordering::Equal) + }); functions.truncate(10); functions } - /// Run a correctness test by executing the benchmark with a small query subset. /// /// Runs the benchmark client in a lightweight mode and checks whether the recall @@ -960,7 +1030,9 @@ pub async fn run_correctness_test(work_dir: &Path, config: &BenchConfig) -> Tool }; // 4. Wait for server readiness - if let Err(e) = wait_for_server_ready(port, SERVER_READY_TIMEOUT_SECS, SERVER_POLL_INTERVAL_MS).await { + if let Err(e) = + wait_for_server_ready(port, SERVER_READY_TIMEOUT_SECS, SERVER_POLL_INTERVAL_MS).await + { kill_server(&mut child).await; return ToolResult::Error { message: format!("Server not ready: {}", e), @@ -1280,6 +1352,61 @@ mod tests { assert!((output.benchmark.qps - 1500.5).abs() < f64::EPSILON); assert_eq!(output.benchmark.total_queries, 10000); assert!(output.benchmark.recall_passed); + assert!(output.anti_cheat.passed); + } + + #[test] + fn test_apply_anti_cheat_guard_invalidates_score() { + let mut benchmark = BenchmarkResult { + qps: 1500.5, + total_queries: 10000, + duration_secs: 6.66, + avg_latency_ms: 2.5, + p50_latency_ms: 2.0, + p95_latency_ms: 5.0, + p99_latency_ms: 10.0, + recall: 0.98, + recall_threshold: 0.95, + recall_passed: true, + concurrency: 4, + comparison: None, + }; + let anti_cheat = AntiCheatOutput { + passed: false, + message: "SUSPICIOUS".to_string(), + }; + + apply_anti_cheat_guard(&mut benchmark, &anti_cheat); + + assert_eq!(benchmark.qps, 0.0); + assert!(!benchmark.recall_passed); + } + + #[test] + fn test_apply_anti_cheat_guard_keeps_clean_result() { + let mut benchmark = BenchmarkResult { + qps: 1500.5, + total_queries: 10000, + duration_secs: 6.66, + avg_latency_ms: 2.5, + p50_latency_ms: 2.0, + p95_latency_ms: 5.0, + p99_latency_ms: 10.0, + recall: 0.98, + recall_threshold: 0.95, + recall_passed: true, + concurrency: 4, + comparison: None, + }; + let anti_cheat = AntiCheatOutput { + passed: true, + message: "OK".to_string(), + }; + + apply_anti_cheat_guard(&mut benchmark, &anti_cheat); + + assert_eq!(benchmark.qps, 1500.5); + assert!(benchmark.recall_passed); } // ─── detect_binary_name tests ─────────────────────────────────────────── @@ -1550,10 +1677,7 @@ fn main() { // ─── helper ────────────────────────────────────────────────────────────── fn tempdir() -> std::path::PathBuf { - let dir = std::env::temp_dir().join(format!( - "bench_tools_test_{}", - uuid::Uuid::new_v4() - )); + let dir = std::env::temp_dir().join(format!("bench_tools_test_{}", uuid::Uuid::new_v4())); std::fs::create_dir_all(&dir).unwrap(); dir } diff --git a/agent/src/main.rs b/agent/src/main.rs index 3cb2ffd..9ad3a92 100644 --- a/agent/src/main.rs +++ b/agent/src/main.rs @@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use std::time::Instant; -use state::AgentState; use logger::AgentLogger; +use state::AgentState; use tools::{dispatch_tool_call, get_tool_definitions, BenchmarkResult, ToolCall, ToolResult}; /// Vector DB Agent - Tool Call Agent for LLM evaluation @@ -157,7 +157,9 @@ struct ChatResponseMessage { const MAX_RETRIES: u32 = 5; /// Build extra body fields based on thinking mode setting. -fn build_thinking_extra(thinking_mode: &str) -> std::collections::HashMap { +fn build_thinking_extra( + thinking_mode: &str, +) -> std::collections::HashMap { let mut extra = std::collections::HashMap::new(); match thinking_mode { "true" | "openai" => { @@ -275,7 +277,10 @@ async fn call_llm( if attempt < MAX_RETRIES - 1 { continue; } - return Err(format!("LLM API request failed after {} retries: {}", MAX_RETRIES, e)); + return Err(format!( + "LLM API request failed after {} retries: {}", + MAX_RETRIES, e + )); } } } @@ -369,7 +374,10 @@ fn build_system_message(content: &str) -> ChatMessage { } /// Build an assistant message with tool calls (from LLM response). -fn build_assistant_tool_calls_message(tool_calls: Vec, reasoning_content: Option) -> ChatMessage { +fn build_assistant_tool_calls_message( + tool_calls: Vec, + reasoning_content: Option, +) -> ChatMessage { ChatMessage { role: "assistant".to_string(), content: None, @@ -380,7 +388,10 @@ fn build_assistant_tool_calls_message(tool_calls: Vec, reasonin } /// Build an assistant message with text content (no tool calls). -fn build_assistant_content_message(content: &str, reasoning_content: Option) -> ChatMessage { +fn build_assistant_content_message( + content: &str, + reasoning_content: Option, +) -> ChatMessage { ChatMessage { role: "assistant".to_string(), content: Some(content.to_string()), @@ -415,7 +426,10 @@ fn save_eval_log(work_dir: &Path, state: &AgentState) { }); let log_path = work_dir.join("eval_log.json"); - match std::fs::write(&log_path, serde_json::to_string_pretty(&log).unwrap_or_default()) { + match std::fs::write( + &log_path, + serde_json::to_string_pretty(&log).unwrap_or_default(), + ) { Ok(()) => eprintln!("[agent] Eval log saved to {}", log_path.display()), Err(e) => eprintln!("[agent] Failed to save eval log: {}", e), } @@ -444,11 +458,7 @@ struct SessionContext { } /// Save session context to `/session_context.json`. -fn save_session_context( - work_dir: &Path, - messages: &[ChatMessage], - state: &AgentState, -) { +fn save_session_context(work_dir: &Path, messages: &[ChatMessage], state: &AgentState) { let ctx = SessionContext { tool_calls_used: state.tool_calls_used, tool_calls_total: state.tool_calls_total, @@ -490,6 +500,14 @@ fn load_session_context(work_dir: &Path) -> Option { } } +fn session_is_finished(ctx: &SessionContext) -> bool { + ctx.call_log + .last() + .and_then(|entry| entry.output.get("type")) + .and_then(|v| v.as_str()) + == Some("Finish") +} + // ─── Main ──────────────────────────────────────────────────────────────────── #[tokio::main] @@ -534,24 +552,22 @@ async fn main() { // Canonicalize all paths to absolute let data_dir = data_dir.canonicalize().unwrap_or_else(|e| { - eprintln!( - "[agent] ERROR: Cannot resolve data_dir: {}", - e - ); + eprintln!("[agent] ERROR: Cannot resolve data_dir: {}", e); std::process::exit(1); }); let benchmark_bin = benchmark_bin.canonicalize().unwrap_or_else(|e| { - eprintln!( - "[agent] ERROR: Cannot resolve benchmark_bin: {}", - e - ); + eprintln!("[agent] ERROR: Cannot resolve benchmark_bin: {}", e); std::process::exit(1); }); let bench_config = bench_tools::BenchConfig { benchmark_bin: benchmark_bin.clone(), data_dir: data_dir.clone(), - cpu_cores: if args.cpu_cores.is_empty() { None } else { Some(args.cpu_cores.clone()) }, + cpu_cores: if args.cpu_cores.is_empty() { + None + } else { + Some(args.cpu_cores.clone()) + }, }; // ── validate-perf mode: build, start server, run perf, report, exit ── @@ -561,7 +577,11 @@ async fn main() { let result = bench_tools::run_profiling(&work_dir, &bench_config, Some(5)).await; match &result { - ToolResult::RunProfiling { flamegraph_svg_path, top_functions, total_samples } => { + ToolResult::RunProfiling { + flamegraph_svg_path, + top_functions, + total_samples, + } => { eprintln!("[validate-perf] ✓ perf profiling succeeded!"); eprintln!("[validate-perf] Total samples: {}", total_samples); if !top_functions.is_empty() { @@ -633,24 +653,32 @@ async fn main() { // ── Resume from saved session context if requested ── let (mut messages, resumed) = if args.resume { match load_session_context(&work_dir) { - Some(ctx) if ctx.tool_calls_used < ctx.tool_calls_total => { - eprintln!( - "[agent] Resuming session: {}/{} tool calls used", - ctx.tool_calls_used, ctx.tool_calls_total - ); - state.tool_calls_used = ctx.tool_calls_used; - state.tool_calls_total = ctx.tool_calls_total; - state.last_benchmark = ctx.last_benchmark; - state.best_benchmark = ctx.best_benchmark; - state.call_log = ctx.call_log; - (ctx.messages, true) - } Some(ctx) => { - eprintln!( - "[agent] Session context found but already completed ({}/{} tool calls). Starting fresh.", - ctx.tool_calls_used, ctx.tool_calls_total - ); - (vec![], false) + let finished = session_is_finished(&ctx); + if ctx.tool_calls_used < ctx.tool_calls_total && !finished { + eprintln!( + "[agent] Resuming session: {}/{} tool calls used", + ctx.tool_calls_used, ctx.tool_calls_total + ); + state.tool_calls_used = ctx.tool_calls_used; + state.tool_calls_total = ctx.tool_calls_total; + state.last_benchmark = ctx.last_benchmark; + state.best_benchmark = ctx.best_benchmark; + state.call_log = ctx.call_log; + (ctx.messages, true) + } else { + if finished { + eprintln!( + "[agent] Session context found but already finished by finish() call. Starting fresh." + ); + } else { + eprintln!( + "[agent] Session context found but already completed ({}/{} tool calls). Starting fresh.", + ctx.tool_calls_used, ctx.tool_calls_total + ); + } + (vec![], false) + } } None => { eprintln!("[agent] No session context found. Starting fresh."); @@ -691,7 +719,8 @@ async fn main() { let finish_call = ToolCall::Finish { summary: "Tool call limit reached - auto finish".to_string(), }; - let result = dispatch_tool_call(&finish_call, &work_dir, &bench_config, &mut state).await; + let result = + dispatch_tool_call(&finish_call, &work_dir, &bench_config, &mut state).await; eprintln!("[agent] Final result: {:?}", result); logger.log_session_end( state.tool_calls_used, @@ -710,7 +739,10 @@ async fn main() { let elapsed = last.elapsed(); if elapsed < api_interval { let wait = api_interval - elapsed; - eprintln!("[agent] Rate limit: waiting {}ms before next API call", wait.as_millis()); + eprintln!( + "[agent] Rate limit: waiting {}ms before next API call", + wait.as_millis() + ); tokio::time::sleep(wait).await; } } @@ -752,19 +784,25 @@ async fn main() { // Process response // Merge reasoning fields: OpenAI/Kimi use "reasoning_content", Gemini uses "reasoning" - let reasoning_content = response_msg.reasoning_content.clone() + let reasoning_content = response_msg + .reasoning_content + .clone() .or_else(|| response_msg.reasoning.clone()); if let Some(tool_calls) = response_msg.tool_calls { if tool_calls.is_empty() { // No tool calls in response, treat as content-only logger.log_llm_response( - false, 0, + false, + 0, response_msg.content.as_deref(), reasoning_content.as_deref(), llm_duration_ms, ); if let Some(content) = &response_msg.content { - eprintln!("[agent] Assistant (no tools): {}", &content[..content.len().min(200)]); + eprintln!( + "[agent] Assistant (no tools): {}", + &content[..content.len().min(200)] + ); messages.push(build_assistant_content_message(content, reasoning_content)); save_session_context(&work_dir, &messages, &state); } else { @@ -790,7 +828,10 @@ async fn main() { ); // Append assistant message with tool calls - messages.push(build_assistant_tool_calls_message(tool_calls.clone(), reasoning_content)); + messages.push(build_assistant_tool_calls_message( + tool_calls.clone(), + reasoning_content, + )); for tc in &tool_calls { let tool_name = &tc.function.name; @@ -817,10 +858,12 @@ async fn main() { let parsed = parse_tool_call(tool_name, tool_args); let start = Instant::now(); - let (result, call_for_log) = match parsed { + let (result, call_for_log, parsed_is_finish) = match parsed { Ok(call) => { - let result = dispatch_tool_call(&call, &work_dir, &bench_config, &mut state).await; - (result, call) + let parsed_is_finish = matches!(&call, ToolCall::Finish { .. }); + let result = + dispatch_tool_call(&call, &work_dir, &bench_config, &mut state).await; + (result, call, parsed_is_finish) } Err(e) => { eprintln!("[agent] Parse error for tool '{}': {}", tool_name, e); @@ -831,6 +874,7 @@ async fn main() { ( result, ToolCall::GetStatus, // placeholder for logging + false, ) } }; @@ -863,7 +907,7 @@ async fn main() { save_session_context(&work_dir, &messages, &state); // Check if this was a finish call - if tool_name == "finish" { + if parsed_is_finish { eprintln!("[agent] Finish tool called. Ending session."); logger.log_session_end( state.tool_calls_used, @@ -885,7 +929,8 @@ async fn main() { summary: "Tool call limit reached - auto finish".to_string(), }; let finish_result = - dispatch_tool_call(&finish_call, &work_dir, &bench_config, &mut state).await; + dispatch_tool_call(&finish_call, &work_dir, &bench_config, &mut state) + .await; eprintln!("[agent] Final result: {:?}", finish_result); logger.log_session_end( state.tool_calls_used, @@ -903,13 +948,25 @@ async fn main() { } } else if let Some(content) = response_msg.content { // Assistant responded with text only (no tool calls) - logger.log_llm_response(false, 0, Some(&content), reasoning_content.as_deref(), llm_duration_ms); + logger.log_llm_response( + false, + 0, + Some(&content), + reasoning_content.as_deref(), + llm_duration_ms, + ); eprintln!("[agent] Assistant: {}", &content[..content.len().min(200)]); messages.push(build_assistant_content_message(&content, reasoning_content)); save_session_context(&work_dir, &messages, &state); } else { // Unexpected: no tool calls and no content - logger.log_llm_response(false, 0, None, reasoning_content.as_deref(), llm_duration_ms); + logger.log_llm_response( + false, + 0, + None, + reasoning_content.as_deref(), + llm_duration_ms, + ); eprintln!("[agent] Unexpected empty response from LLM. Ending session."); logger.log_session_end( state.tool_calls_used, @@ -944,7 +1001,10 @@ async fn main() { best.qps, best.recall, best.recall_passed ); } - eprintln!("[agent] Real-time log saved to: {}", logger.path().display()); + eprintln!( + "[agent] Real-time log saved to: {}", + logger.path().display() + ); } // ─── Tests ─────────────────────────────────────────────────────────────────── @@ -988,7 +1048,8 @@ mod tests { arguments: r#"{"path": "src/main.rs"}"#.to_string(), }, }]; - let msg = build_assistant_tool_calls_message(tool_calls.clone(), Some("reasoning".to_string())); + let msg = + build_assistant_tool_calls_message(tool_calls.clone(), Some("reasoning".to_string())); assert_eq!(msg.role, "assistant"); assert!(msg.content.is_none()); assert_eq!(msg.reasoning_content.as_deref(), Some("reasoning")); @@ -1057,11 +1118,8 @@ mod tests { #[test] fn test_parse_run_benchmark_with_params() { - let call = parse_tool_call( - "run_benchmark", - r#"{"concurrency": 8, "warmup": 500}"#, - ) - .unwrap(); + let call = + parse_tool_call("run_benchmark", r#"{"concurrency": 8, "warmup": 500}"#).unwrap(); match call { ToolCall::RunBenchmark { concurrency, @@ -1122,8 +1180,7 @@ mod tests { #[test] fn test_parse_finish() { - let call = - parse_tool_call("finish", r#"{"summary": "Optimized search"}"#).unwrap(); + let call = parse_tool_call("finish", r#"{"summary": "Optimized search"}"#).unwrap(); match call { ToolCall::Finish { summary } => assert_eq!(summary, "Optimized search"), _ => panic!("Expected Finish"), @@ -1377,4 +1434,46 @@ mod tests { let _ = std::fs::remove_dir_all(&dir); } + + #[test] + fn test_session_is_finished_detects_finish_output() { + let ctx = SessionContext { + tool_calls_used: 3, + tool_calls_total: 50, + messages: vec![], + last_benchmark: None, + best_benchmark: None, + call_log: vec![state::ToolCallLog { + index: 3, + tool: "finish".to_string(), + input: serde_json::json!({"tool": "finish"}), + output: serde_json::json!({"type": "Finish", "status": "done"}), + duration_ms: 1, + timestamp: chrono::Utc::now(), + }], + }; + + assert!(session_is_finished(&ctx)); + } + + #[test] + fn test_session_is_finished_false_for_non_finish_output() { + let ctx = SessionContext { + tool_calls_used: 3, + tool_calls_total: 50, + messages: vec![], + last_benchmark: None, + best_benchmark: None, + call_log: vec![state::ToolCallLog { + index: 3, + tool: "finish".to_string(), + input: serde_json::json!({"tool": "finish"}), + output: serde_json::json!({"type": "Error", "message": "parse failed"}), + duration_ms: 1, + timestamp: chrono::Utc::now(), + }], + }; + + assert!(!session_is_finished(&ctx)); + } } diff --git a/scripts/generate_ground_truth.py b/scripts/generate_ground_truth.py index 821f95e..eacabc1 100644 --- a/scripts/generate_ground_truth.py +++ b/scripts/generate_ground_truth.py @@ -333,6 +333,27 @@ def _chunk_path(data_dir, chunk_id): return os.path.join(data_dir, f"ground_truth_chunk_{chunk_id}.json") +def _is_valid_chunk(existing_chunk, q_start, q_end, top_k): + """Validate whether a previously saved chunk can be reused safely.""" + expected_len = q_end - q_start + if not isinstance(existing_chunk, list) or len(existing_chunk) != expected_len: + return False + + for offset, row in enumerate(existing_chunk): + if not isinstance(row, dict): + return False + expected_query_id = q_start + offset + if row.get("query_id") != expected_query_id: + return False + neighbors = row.get("neighbors") + if not isinstance(neighbors, list) or len(neighbors) != top_k: + return False + if not all(isinstance(n, int) for n in neighbors): + return False + + return True + + def main(): parser = argparse.ArgumentParser( description="Generate ground truth via brute-force L2 nearest neighbor search." @@ -429,6 +450,13 @@ def main(): if use_numpy: del base_vectors, query_vectors + # Convert base data once for worker transport. + # Avoid repeating base_data.tolist() per chunk (huge memory overhead). + if use_numpy: + base_list = base_data.tolist() + else: + base_list = base_data + os.makedirs(data_dir, exist_ok=True) # Build chunk task list, skipping already-completed chunks (resume support) @@ -443,11 +471,11 @@ def main(): all_chunk_paths.append(cp) if os.path.isfile(cp): - # Validate existing chunk has correct number of entries + # Validate existing chunk can be reused for current params. try: with open(cp, "r", encoding="utf-8") as f: existing = json.load(f) - if len(existing) == (q_end - q_start): + if _is_valid_chunk(existing, q_start, q_end, top_k): skipped += 1 continue except (json.JSONDecodeError, OSError): @@ -455,10 +483,8 @@ def main(): if use_numpy: query_chunk = query_data[q_start:q_end].tolist() - base_list = base_data.tolist() else: query_chunk = query_data[q_start:q_end] - base_list = base_data pending_tasks.append( (base_list, query_chunk, top_k, chunk_id, q_start, cp, use_numpy) diff --git a/scripts/run_eval.sh b/scripts/run_eval.sh index 7bfa463..7568614 100644 --- a/scripts/run_eval.sh +++ b/scripts/run_eval.sh @@ -261,145 +261,150 @@ step_collect_results() { # Build a combined result file with model metadata. # Uses best_benchmark (tracked across all runs) as primary source, # falls back to last_benchmark, then scans benchmarks/*.json files. - python3 -c " -import json, sys, os, glob + python3 - "${eval_log}" "${MODEL_NAME}" "${MODEL_ID}" "${WORK_DIR}" "${result_file}" <<'PY' || die "Failed to collect results" +import glob +import json +import os +import sys -eval_log_path = '${eval_log}' -model_name = '${MODEL_NAME}' -model_id = '${MODEL_ID}' -work_dir = '${WORK_DIR}' +if len(sys.argv) != 6: + raise SystemExit("Expected arguments: eval_log model_name model_id work_dir result_file") -with open(eval_log_path, 'r') as f: +eval_log_path, model_name, model_id, work_dir, result_file = sys.argv[1:6] + +with open(eval_log_path, "r", encoding="utf-8") as f: eval_log = json.load(f) result = { - 'model_name': model_name, - 'model_id': model_id, - 'eval_log': eval_log, + "model_name": model_name, + "model_id": model_id, + "eval_log": eval_log, } -# Priority 1: best_benchmark from eval_log (tracked across all runs by agent) -best_bench = eval_log.get('best_benchmark') - -# Priority 2: last_benchmark (from the final finish() call) -last_bench = eval_log.get('last_benchmark') -# Priority 3: scan all benchmarks/*.json for the highest QPS with passing recall def scan_benchmark_files(): - bench_dir = os.path.join(work_dir, 'benchmarks') + bench_dir = os.path.join(work_dir, "benchmarks") if not os.path.isdir(bench_dir): return None best = None - for f in sorted(glob.glob(os.path.join(bench_dir, 'benchmark_*.json'))): + for path in sorted(glob.glob(os.path.join(bench_dir, "benchmark_*.json"))): try: - with open(f, 'r') as fh: - b = json.load(fh) - if b.get('recall_passed', False) and b.get('qps', 0) > 0: - if best is None or b['qps'] > best['qps']: - best = b + with open(path, "r", encoding="utf-8") as f: + bench = json.load(f) + if bench.get("recall_passed", False) and bench.get("qps", 0) > 0: + if best is None or bench["qps"] > best["qps"]: + best = bench except (json.JSONDecodeError, KeyError): continue return best -# Pick the best available result + +best_bench = eval_log.get("best_benchmark") +last_bench = eval_log.get("last_benchmark") chosen = None -chosen_source = 'none' +chosen_source = "none" -if best_bench and best_bench.get('recall_passed', False) and best_bench.get('qps', 0) > 0: +if best_bench and best_bench.get("recall_passed", False) and best_bench.get("qps", 0) > 0: chosen = best_bench - chosen_source = 'best_benchmark' + chosen_source = "best_benchmark" -if last_bench and last_bench.get('recall_passed', False) and last_bench.get('qps', 0) > 0: - if chosen is None or last_bench['qps'] > chosen['qps']: +if last_bench and last_bench.get("recall_passed", False) and last_bench.get("qps", 0) > 0: + if chosen is None or last_bench["qps"] > chosen["qps"]: chosen = last_bench - chosen_source = 'last_benchmark' + chosen_source = "last_benchmark" scanned = scan_benchmark_files() -if scanned: - if chosen is None or scanned['qps'] > chosen['qps']: - chosen = scanned - chosen_source = 'benchmark_files_scan' +if scanned and (chosen is None or scanned["qps"] > chosen["qps"]): + chosen = scanned + chosen_source = "benchmark_files_scan" if chosen: - result['final_benchmark'] = chosen - result['qps'] = chosen.get('qps', 0) - result['recall'] = chosen.get('recall', 0) - result['recall_passed'] = chosen.get('recall_passed', False) - result['result_source'] = chosen_source - print(f'[collect] Using {chosen_source}: QPS={chosen[\"qps\"]:.2f}, Recall={chosen[\"recall\"]:.4f}', file=sys.stderr) + result["final_benchmark"] = chosen + result["qps"] = chosen.get("qps", 0) + result["recall"] = chosen.get("recall", 0) + result["recall_passed"] = chosen.get("recall_passed", False) + result["result_source"] = chosen_source + print( + f"[collect] Using {chosen_source}: QPS={chosen['qps']:.2f}, Recall={chosen['recall']:.4f}", + file=sys.stderr, + ) else: - result['qps'] = 0 - result['recall'] = 0 - result['recall_passed'] = False - result['result_source'] = 'none' - print('[collect] No valid benchmark result found', file=sys.stderr) + result["qps"] = 0 + result["recall"] = 0 + result["recall_passed"] = False + result["result_source"] = "none" + print("[collect] No valid benchmark result found", file=sys.stderr) -with open('${result_file}', 'w') as f: +with open(result_file, "w", encoding="utf-8") as f: json.dump(result, f, indent=2) -print(json.dumps({ - 'model': model_name, - 'qps': result['qps'], - 'recall': result['recall'], - 'recall_passed': result['recall_passed'], - 'tool_calls_used': eval_log.get('tool_calls_used', 0), - 'result_source': result.get('result_source', 'none'), -}, indent=2)) -" || die "Failed to collect results" +print( + json.dumps( + { + "model": model_name, + "qps": result["qps"], + "recall": result["recall"], + "recall_passed": result["recall_passed"], + "tool_calls_used": eval_log.get("tool_calls_used", 0), + "result_source": result.get("result_source", "none"), + }, + indent=2, + ) +) +PY log " Results saved to ${result_file}" # Update leaderboard local leaderboard="${RESULTS_DIR}/leaderboard.json" - python3 -c " -import json, sys + python3 - "${result_file}" "${leaderboard}" <<'PY' || die "Failed to update leaderboard" +import json +import sys from datetime import datetime, timezone -result_path = '${result_file}' -leaderboard_path = '${leaderboard}' +if len(sys.argv) != 3: + raise SystemExit("Expected arguments: result_file leaderboard_path") + +result_path, leaderboard_path = sys.argv[1:3] -with open(result_path, 'r') as f: +with open(result_path, "r", encoding="utf-8") as f: result = json.load(f) -# Load existing leaderboard or start fresh try: - with open(leaderboard_path, 'r') as f: + with open(leaderboard_path, "r", encoding="utf-8") as f: entries = json.load(f) except (FileNotFoundError, json.JSONDecodeError): entries = [] -# Compute final score: QPS = 0 if recall < 0.95 -qps = result.get('qps', 0) -recall = result.get('recall', 0) +qps = result.get("qps", 0) +recall = result.get("recall", 0) if recall < 0.95: qps = 0 entry = { - 'model_name': result['model_name'], - 'qps': qps, - 'recall': recall, - 'recall_passed': result.get('recall_passed', False), - 'tool_calls_used': result.get('eval_log', {}).get('tool_calls_used', 0), - 'result_source': result.get('result_source', 'unknown'), - 'timestamp': datetime.now(timezone.utc).isoformat(), + "model_name": result["model_name"], + "qps": qps, + "recall": recall, + "recall_passed": result.get("recall_passed", False), + "tool_calls_used": result.get("eval_log", {}).get("tool_calls_used", 0), + "result_source": result.get("result_source", "unknown"), + "timestamp": datetime.now(timezone.utc).isoformat(), } entries.append(entry) +entries.sort(key=lambda item: (-item["qps"], -item["recall"])) -# Sort: QPS descending, then recall descending for ties -entries.sort(key=lambda e: (-e['qps'], -e['recall'])) - -with open(leaderboard_path, 'w') as f: +with open(leaderboard_path, "w", encoding="utf-8") as f: json.dump(entries, f, indent=2) print() -print('=== Leaderboard ===') -for i, e in enumerate(entries): - marker = ' <-- NEW' if e['model_name'] == result['model_name'] and e['timestamp'] == entry['timestamp'] else '' - src = f\" [{e.get('result_source', '?')}]\" if e.get('result_source') else '' - print(f\" {i+1}. {e['model_name']:20s} QPS: {e['qps']:>10.2f} Recall: {e['recall']:.4f}{src}{marker}\") +print("=== Leaderboard ===") +for i, item in enumerate(entries): + marker = " <-- NEW" if item["model_name"] == result["model_name"] and item["timestamp"] == entry["timestamp"] else "" + src = f" [{item.get('result_source', '?')}]" if item.get("result_source") else "" + print(f" {i+1}. {item['model_name']:20s} QPS: {item['qps']:>10.2f} Recall: {item['recall']:.4f}{src}{marker}") print() -" || die "Failed to update leaderboard" +PY log " Leaderboard updated at ${leaderboard}" } diff --git a/scripts/test_generate_ground_truth.py b/scripts/test_generate_ground_truth.py index b1e869c..14af2ff 100644 --- a/scripts/test_generate_ground_truth.py +++ b/scripts/test_generate_ground_truth.py @@ -5,6 +5,7 @@ import math import os import struct +import subprocess import sys import tempfile import unittest @@ -14,6 +15,7 @@ read_fvecs, compute_ground_truth, _l2_distance_squared_python, + _is_valid_chunk, ) @@ -132,6 +134,29 @@ def test_file_not_found(self): read_fvecs("/nonexistent/path.fvecs") +class TestChunkValidation(unittest.TestCase): + def test_chunk_validation_rejects_wrong_top_k(self): + chunk = [ + {"query_id": 10, "neighbors": [1, 2]}, + {"query_id": 11, "neighbors": [3, 4]}, + ] + self.assertFalse(_is_valid_chunk(chunk, 10, 12, top_k=1)) + + def test_chunk_validation_rejects_wrong_query_id(self): + chunk = [ + {"query_id": 10, "neighbors": [1]}, + {"query_id": 99, "neighbors": [2]}, + ] + self.assertFalse(_is_valid_chunk(chunk, 10, 12, top_k=1)) + + def test_chunk_validation_accepts_matching_chunk(self): + chunk = [ + {"query_id": 10, "neighbors": [1]}, + {"query_id": 11, "neighbors": [2]}, + ] + self.assertTrue(_is_valid_chunk(chunk, 10, 12, top_k=1)) + + class TestEndToEndSmall(unittest.TestCase): """End-to-end test: write fvecs, compute ground truth, verify JSON output.""" @@ -175,6 +200,64 @@ def test_small_dataset(self): import shutil shutil.rmtree(tmpdir) + def test_resume_recompute_when_top_k_changes(self): + tmpdir = tempfile.mkdtemp() + try: + base = [[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]] + queries = [[0.1, 0.1], [9.9, 9.9]] + + write_fvecs(os.path.join(tmpdir, "sift_base.fvecs"), base) + write_fvecs(os.path.join(tmpdir, "sift_query.fvecs"), queries) + + script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "generate_ground_truth.py") + + subprocess.run( + [ + sys.executable, + script_path, + "--data-dir", tmpdir, + "--top-k", "2", + "--chunk-size", "1", + "--workers", "1", + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + with open(os.path.join(tmpdir, "ground_truth.json"), "r", encoding="utf-8") as f: + stale = json.load(f) + + for chunk_id, row in enumerate(stale): + with open( + os.path.join(tmpdir, f"ground_truth_chunk_{chunk_id}.json"), + "w", + encoding="utf-8", + ) as f: + json.dump([row], f) + + subprocess.run( + [ + sys.executable, + script_path, + "--data-dir", tmpdir, + "--top-k", "1", + "--chunk-size", "1", + "--workers", "1", + ], + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + with open(os.path.join(tmpdir, "ground_truth.json"), "r", encoding="utf-8") as f: + updated = json.load(f) + + self.assertTrue(all(len(entry["neighbors"]) == 1 for entry in updated)) + finally: + import shutil + shutil.rmtree(tmpdir) + if __name__ == "__main__": unittest.main()