diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aca79d8..0c8d9e9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -52,7 +52,7 @@ jobs: stdin-integration: name: CLI stdin integration (${{ matrix.os }}) runs-on: ${{ matrix.os }} - timeout-minutes: 20 + timeout-minutes: 30 strategy: fail-fast: false matrix: @@ -70,6 +70,9 @@ jobs: rustup default 1.92.0 rustup show active-toolchain + - name: Cache cargo artifacts + uses: Swatinem/rust-cache@v2 + - name: Run stdin integration tests run: cargo test -p wraithrun --test stdin_integration @@ -214,27 +217,31 @@ jobs: - name: Checkout uses: actions/checkout@v6 - - name: Install Rust toolchain (powershell) + - name: Install Rust toolchain shell: powershell run: | $cargobin = "$env:USERPROFILE\.cargo\bin" - if (-not ($env:Path -split ';' | Where-Object { $_ -eq $cargobin })) { - $env:Path = "$cargobin;$env:Path" - echo "$cargobin" >> $env:GITHUB_PATH - } + $env:Path = "$cargobin;$env:Path" if (-not (Get-Command rustup -ErrorAction SilentlyContinue)) { Write-Host "rustup not found, installing..." 
Invoke-WebRequest -Uri https://win.rustup.rs/x86_64 -OutFile rustup-init.exe .\rustup-init.exe -y --default-toolchain 1.92.0 --profile minimal Remove-Item .\rustup-init.exe - $env:Path = "$cargobin;$env:Path" } rustup toolchain install 1.92.0 --profile minimal rustup default 1.92.0 - rustup show active-toolchain + # Resolve the real toolchain bin directory (cargo proxy may be absent) + $realBin = Split-Path (& rustup which cargo) + Write-Host "Toolchain bin: $realBin" + $env:Path = "$realBin;$env:Path" + cargo --version + $utf8 = New-Object System.Text.UTF8Encoding($false) + [IO.File]::AppendAllText($env:GITHUB_ENV, "CARGO_BIN=$realBin`n", $utf8) + [IO.File]::AppendAllText($env:GITHUB_PATH, "$realBin`n", $utf8) - name: Cache cargo artifacts uses: Swatinem/rust-cache@v2 + continue-on-error: true - name: Validate live e2e fixture configuration shell: powershell @@ -247,6 +254,9 @@ jobs: - name: Run live success e2e test (no fallback) shell: powershell run: | + $bin = if ($env:CARGO_BIN) { $env:CARGO_BIN } else { "$env:USERPROFILE\.cargo\bin" } + $env:Path = "$bin;$env:Path" + Write-Host "cargo at: $(Get-Command cargo -ErrorAction SilentlyContinue | Select-Object -ExpandProperty Source)" cargo test -p wraithrun --features inference_bridge/onnx --test stdin_integration live_mode_e2e_success_without_fallback_when_fixture_is_configured -- --exact --nocapture - name: Upload live success e2e artifacts diff --git a/.gitignore b/.gitignore index 311c0e0..42ae881 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,9 @@ Thumbs.db /launch-assets/generated/ /launch-assets/reports/ /launch-assets/*.json +/test_outputs/*.db +/test_outputs/*.db-shm +/test_outputs/*.db-wal # Local GitHub Actions runner /actions-runner/ diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 6835b47..24aeb07 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -5,7 +5,7 @@ edition.workspace = true license.workspace = true [features] -default = [] +default = ["onnx"] onnx = ["inference_bridge/onnx"] vitis 
= ["inference_bridge/vitis"] directml = ["inference_bridge/directml"] diff --git a/cli/src/main.rs b/cli/src/main.rs index b43afd7..e7c37b0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -294,6 +294,13 @@ fn validate_live_runtime_preflight(runtime: &RuntimeConfig) -> Result<()> { return Ok(()); } + // Fail fast when the binary was compiled without inference support (#149). + #[cfg(not(feature = "onnx"))] + bail!( + "Live inference requested but this binary was built without inference support. \ + Rebuild with `--features onnx` (or `--features vitis`/`--features directml`)." + ); + if !runtime.model.is_file() { bail!( "Live mode model file not found: {}. Run '--doctor --live --introspection-format json' (or '--doctor --live --fix') and provide a readable --model path.", diff --git a/core_engine/src/agent.rs b/core_engine/src/agent.rs index 4f117c1..cc75b9c 100644 --- a/core_engine/src/agent.rs +++ b/core_engine/src/agent.rs @@ -351,9 +351,12 @@ impl Agent { fn check_tool_precondition(&self, tool_name: &str) -> bool { match tool_name { "read_syslog" => { - // Default path is ./agent.log — skip if it doesn't exist and - // the sandbox policy would deny access anyway. - let default_path = std::path::Path::new("./agent.log"); + // Use a platform-appropriate default log path (#153). 
+ let default_path = if cfg!(target_os = "windows") { + std::path::Path::new("C:\\Windows\\System32\\winevt\\Logs\\System.evtx") + } else { + std::path::Path::new("/var/log/syslog") + }; if !default_path.exists() { return false; } diff --git a/core_engine/src/lib.rs b/core_engine/src/lib.rs index 262569e..4b6e501 100644 --- a/core_engine/src/lib.rs +++ b/core_engine/src/lib.rs @@ -1,6 +1,6 @@ pub mod agent; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use inference_bridge::ModelCapabilityProbe; use serde::{Deserialize, Serialize, Serializer}; @@ -101,7 +101,9 @@ static BUILTIN_TEMPLATES: [InvestigationTemplate; 7] = [ // ── Capability tiering thresholds (const, easy to tune) ── /// Models below this parameter count (billions) are classified as Basic. -const PARAM_BASIC_CEILING_B: f32 = 2.0; +/// Lowered from 2.0 → 1.0 so common 1B+ models (Qwen2.5-0.5B overestimates +/// to ~1.4B, Llama-3.2-1B at ~1.12B) reach Moderate tier and use inference (#157). +const PARAM_BASIC_CEILING_B: f32 = 1.0; /// Models above this parameter count (billions) are classified as Strong. const PARAM_STRONG_FLOOR_B: f32 = 10.0; /// Latency above this (ms/tok) demotes to Basic. @@ -1179,10 +1181,19 @@ pub fn basic_tier_summary_for_task(findings: &[Finding], task: Option<&str>) -> return format!("{prefix}\nFINDINGS:\n(none)\nRISK: info\nACTIONS:\n(none)"); } - let max_sev = findings - .iter() + // Sort findings: highest severity first, then by confidence descending (#155). 
+ let mut sorted: Vec<&Finding> = findings.iter().collect(); + sorted.sort_by(|a, b| { + b.severity.cmp(&a.severity).then_with(|| { + b.confidence + .partial_cmp(&a.confidence) + .unwrap_or(std::cmp::Ordering::Equal) + }) + }); + + let max_sev = sorted + .first() .map(|f| f.severity) - .max() .unwrap_or(FindingSeverity::Info); let distinct_tools: HashSet<&str> = findings @@ -1190,34 +1201,97 @@ pub fn basic_tier_summary_for_task(findings: &[Finding], task: Option<&str>) -> .filter_map(|f| f.evidence_pointer.tool.as_deref()) .collect(); + // -- Header -- let mut out = match task { Some(t) => format!( - "SUMMARY: Task \"{t}\" produced {} findings across {} tool(s). Maximum severity: {}.\nFINDINGS:\n", - findings.len(), - distinct_tools.len(), - max_sev.token() + "INVESTIGATION SUMMARY — \"{t}\"\n\ + {total} findings across {tools} tool(s). Maximum severity: {sev}.\n\n", + total = findings.len(), + tools = distinct_tools.len(), + sev = max_sev.token().to_uppercase() ), None => format!( - "SUMMARY: {} findings detected. Maximum severity: {}.\nFINDINGS:\n", - findings.len(), - max_sev.token() + "INVESTIGATION SUMMARY\n\ + {total} findings detected. Maximum severity: {sev}.\n\n", + total = findings.len(), + sev = max_sev.token().to_uppercase() ), }; - for (i, f) in findings.iter().enumerate() { - out.push_str(&format!( - "{}. 
{} [{}] — {}\n", - i + 1, - f.title, - f.severity.token(), - f.recommended_action - )); + // -- Group by severity (highest first) -- + let severity_order = [ + FindingSeverity::Critical, + FindingSeverity::High, + FindingSeverity::Medium, + FindingSeverity::Low, + FindingSeverity::Info, + ]; + + for &sev in &severity_order { + let group: Vec<&&Finding> = sorted.iter().filter(|f| f.severity == sev).collect(); + if group.is_empty() { + continue; + } + let header = match sev { + FindingSeverity::Critical => "🔴 CRITICAL", + FindingSeverity::High => "🟠 HIGH", + FindingSeverity::Medium => "🟡 MEDIUM", + FindingSeverity::Low => "🔵 LOW", + FindingSeverity::Info => "ℹ️ INFO", + }; + out.push_str(&format!("── {} ({}) ──\n", header, group.len())); + for f in &group { + let tool_tag = f + .evidence_pointer + .tool + .as_deref() + .map(|t| format!(" [{}]", t)) + .unwrap_or_default(); + out.push_str(&format!( + " • {}{} — {}\n", + f.title, tool_tag, f.recommended_action + )); + } + out.push('\n'); } - out.push_str(&format!("RISK: {}\nACTIONS:\n", max_sev.token())); + // -- Cross-references: find tools whose findings overlap -- + if distinct_tools.len() > 1 { + // Collect tool→titles mapping for cross-reference hints. + let mut tool_titles: HashMap<&str, Vec<&str>> = HashMap::new(); + for f in &sorted { + if let Some(tool) = f.evidence_pointer.tool.as_deref() { + tool_titles.entry(tool).or_default().push(&f.title); + } + } + if tool_titles.len() > 1 { + out.push_str("CROSS-REFERENCES:\n"); + let tools_vec: Vec<&&str> = tool_titles.keys().collect(); + out.push_str(&format!( + " Data was collected from {} sources ({}). ", + tools_vec.len(), + tools_vec.iter().map(|t| **t).collect::<Vec<_>>().join(", ") + )); + out.push_str("Correlate findings across tools for a complete picture.\n\n"); + } + } - for (i, f) in findings.iter().enumerate() { - out.push_str(&format!("{}. 
{}\n", i + 1, f.recommended_action)); + // -- Risk assessment -- + out.push_str(&format!( + "OVERALL RISK: {}\n\n", + max_sev.token().to_uppercase() + )); + + // -- Prioritized actions (deduplicated, urgent first) -- + out.push_str("RECOMMENDED ACTIONS (priority order):\n"); + let mut seen_actions: HashSet<&str> = HashSet::new(); + let mut action_idx = 0usize; + for f in &sorted { + let action = f.recommended_action.as_str(); + if seen_actions.insert(action) { + action_idx += 1; + out.push_str(&format!(" {}. {}\n", action_idx, action)); + } } // Remove trailing newline for clean output. @@ -1733,8 +1807,9 @@ mod tests { #[test] fn classify_small_model_as_basic() { + // With PARAM_BASIC_CEILING_B = 1.0 (#157), a 0.5B model is Basic. let probe = ModelCapabilityProbe { - estimated_param_billions: 1.2, + estimated_param_billions: 0.5, execution_provider: "CPUExecutionProvider".to_string(), smoke_latency_ms: 80, vocab_size: 32000, @@ -1742,6 +1817,18 @@ mod tests { assert_eq!(classify_capability(&probe), ModelCapabilityTier::Basic); } + #[test] + fn classify_1b_model_as_moderate() { + // With PARAM_BASIC_CEILING_B = 1.0 (#157), a 1.2B model reaches Moderate. + let probe = ModelCapabilityProbe { + estimated_param_billions: 1.2, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 80, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Moderate); + } + #[test] fn classify_medium_model_moderate_latency_as_moderate() { let probe = ModelCapabilityProbe { @@ -1859,12 +1946,15 @@ mod tests { ]; let summary = basic_tier_summary(&findings); - assert!(summary.starts_with("SUMMARY: 2 findings detected. Maximum severity: high.")); - assert!(summary.contains("FINDINGS:")); - assert!(summary.contains("1. Active listeners [high]")); - assert!(summary.contains("2. 
Suspicious persistence [medium]")); - assert!(summary.contains("RISK: high")); - assert!(summary.contains("ACTIONS:")); + // Updated format groups by severity and provides cross-references (#155). + assert!(summary.contains("INVESTIGATION SUMMARY")); + assert!(summary.contains("2 findings")); + assert!(summary.contains("HIGH")); + assert!(summary.contains("Active listeners")); + assert!(summary.contains("Suspicious persistence")); + assert!(summary.contains("OVERALL RISK: HIGH")); + assert!(summary.contains("RECOMMENDED ACTIONS")); + assert!(summary.contains("CROSS-REFERENCES")); } #[test] @@ -1891,8 +1981,10 @@ mod tests { "observation.indicator_count", )]; let summary = basic_tier_summary_for_task(&findings, Some("windows-triage")); - assert!(summary.contains("Task \"windows-triage\"")); + // Updated format includes task name in header (#155). + assert!(summary.contains("windows-triage")); assert!(summary.contains("1 tool(s)")); + assert!(summary.contains("RECOMMENDED ACTIONS")); } // ── Discrete confidence label tests (#85) ── diff --git a/cyber_tools/src/lib.rs b/cyber_tools/src/lib.rs index 3bd5862..ea71474 100644 --- a/cyber_tools/src/lib.rs +++ b/cyber_tools/src/lib.rs @@ -81,10 +81,12 @@ impl SandboxPolicy { } #[cfg(target_os = "windows")] - let command_allowlist: HashSet<String> = ["whoami", "netstat", "net", "tasklist", "reg"] - .into_iter() - .map(|c| c.to_string()) - .collect(); + let command_allowlist: HashSet<String> = [ + "whoami", "netstat", "net", "tasklist", "reg", "sc", "wmic", "schtasks", + ] + .into_iter() + .map(|c| c.to_string()) + .collect(); #[cfg(not(target_os = "windows"))] let command_allowlist: HashSet<String> = ["id", "ss", "sudo"] diff --git a/inference_bridge/src/lib.rs b/inference_bridge/src/lib.rs index 65414b4..4950c88 100644 --- a/inference_bridge/src/lib.rs +++ b/inference_bridge/src/lib.rs @@ -495,8 +495,16 @@ impl OnnxVitisEngine { format!(r#"{{"tool":"hash_binary","args":{{"path":"{path}"}}}}"#) } "read_syslog" => { + // Use a platform-appropriate
default log path instead of a + // project file like README.md, which would produce bogus + // findings (#153). + let default_path = if cfg!(target_os = "windows") { + "C:\\Windows\\System32\\winevt\\Logs\\System.evtx".to_string() + } else { + "/var/log/syslog".to_string() + }; let path = Self::guess_path_from_task(task) - .unwrap_or_else(|| "./README.md".to_string()); + .unwrap_or(default_path); let path = Self::escape_json_string(&path); let max_lines = Self::guess_line_count_from_task(task).unwrap_or(200); format!( diff --git a/inference_bridge/src/onnx_vitis.rs b/inference_bridge/src/onnx_vitis.rs index b3a6079..796c3e4 100644 --- a/inference_bridge/src/onnx_vitis.rs +++ b/inference_bridge/src/onnx_vitis.rs @@ -14,6 +14,7 @@ use std::{ collections::{HashMap, HashSet}, ffi::CString, fs, + io::Read, path::PathBuf, time::Instant, }; @@ -75,7 +76,7 @@ impl RuntimeCompatibilityReport { .any(|issue| issue.severity == RuntimeCompatibilitySeverity::Fail) } - #[cfg(any(not(feature = "onnx"), test))] + #[cfg(not(feature = "onnx"))] fn push_warn(&mut self, reason_code: &'static str, detail: impl Into<String>) { self.issues.push(RuntimeCompatibilityIssue { severity: RuntimeCompatibilitySeverity::Warn, @@ -255,7 +256,12 @@ fn classify_session_init_reason_code(error_text: &str) -> &'static str { return "runtime_vitis_provider_missing"; } - if normalized.contains("onnxruntime.dll") && normalized.contains("not found") { + if (normalized.contains("onnxruntime.dll") + || normalized.contains("libonnxruntime.so") + || normalized.contains("libonnxruntime.dylib") + || normalized.contains("onnx runtime library")) + && normalized.contains("not found") + { return "runtime_ort_dylib_missing"; } @@ -1370,9 +1376,86 @@ fn build_session_with_vitis_cascade(config: &ModelConfig) -> Result<Session> { } } +#[cfg(feature = "onnx")] +fn ensure_ort_dylib_available() -> Result<()> { + if let Some(path) = std::env::var_os("ORT_DYLIB_PATH") { + let p = PathBuf::from(&path); + if p.is_file() { + return Ok(()); + } + 
bail!( + "ORT_DYLIB_PATH is set to '{}' but the file does not exist; \ + install ONNX Runtime or set ORT_DYLIB_PATH / WRAITHRUN_ORT_DYLIB_PATH \ + to a valid onnxruntime library path", + p.display() + ); + } + + // No explicit path — probe system search locations. + let lib_names: &[&str] = if cfg!(windows) { + &["onnxruntime.dll"] + } else if cfg!(target_os = "macos") { + &["libonnxruntime.dylib"] + } else { + &["libonnxruntime.so"] + }; + + let search_dirs: Vec<PathBuf> = std::env::var_os("PATH") + .map(|val| std::env::split_paths(&val).collect()) + .unwrap_or_default(); + + for dir in &search_dirs { + for name in lib_names { + if dir.join(name).is_file() { + return Ok(()); + } + } + } + + bail!( + "ONNX Runtime library ({}) not found on PATH or via ORT_DYLIB_PATH; \ + install ONNX Runtime or set ORT_DYLIB_PATH / WRAITHRUN_ORT_DYLIB_PATH", + lib_names.join(" / ") + ); +} + +/// Fast pre-validation: confirms the model file starts with a plausible +/// protobuf field tag (ONNX models are serialized `ModelProto` messages). +/// This prevents feeding garbage bytes to `commit_from_file`, which can +/// block indefinitely in the ONNX Runtime graph optimiser. +#[cfg(feature = "onnx")] +fn validate_model_preamble(path: &std::path::Path) -> Result<()> { + let mut file = fs::File::open(path)?; + let mut buf = [0u8; 4]; + let n = file.read(&mut buf)?; + if n < 2 { + bail!( + "invalid model file '{}': too small ({n} bytes)", + path.display() + ); + } + // Protobuf field tags: low 3 bits = wire type, remaining bits = field number. + // Valid wire types for ModelProto fields: 0 (Varint), 1 (64-bit), + // 2 (Length-delimited), 5 (32-bit). Wire types 3, 4 are deprecated; + // 6, 7 are invalid.
+ let wire_type = buf[0] & 0x07; + let field_number = buf[0] >> 3; + if field_number == 0 || wire_type == 3 || wire_type == 4 || wire_type > 5 { + bail!( + "invalid model file '{}': does not start with a valid ONNX protobuf header \ + (first byte 0x{:02X} is not a valid protobuf field tag)", + path.display(), + buf[0] + ); + } + Ok(()) +} + #[cfg(feature = "onnx")] fn build_session(config: &ModelConfig) -> Result<Session> { configure_ort_dylib_path(config); + ensure_ort_dylib_available()?; + validate_model_preamble(&config.model_path)?; // When the vitis feature is available, delegate to the Vitis EP cascade // unless the user explicitly forces CPU-only mode. @@ -2452,6 +2535,7 @@ pub fn inspect_runtime_compatibility( } #[cfg(feature = "onnx")] +#[allow(clippy::too_many_arguments)] fn run_prompt_shared_buffer( session: &mut Session, layout: &SessionLayout, @@ -2712,6 +2796,10 @@ fn run_prompt_on_session( bail!("prompt encoding produced no token IDs"); } + // Track initial KV-cache padding length so the decode loop can account + // for it in the attention mask (#147). + let mut initial_cache_len: usize = 0; + if cache_enabled { // --- Prefix cache reuse (#65) --- // Find how many leading tokens match the previous prompt. @@ -2738,13 +2826,8 @@ if !suffix.is_empty() { let prefill_started = Instant::now(); let attention_len = context_ids.len(); - let model_inputs = build_model_inputs( - &cache.layout, - &suffix.to_vec(), - attention_len, - true, - &cache_state, - )?; + let model_inputs = + build_model_inputs(&cache.layout, suffix, attention_len, true, &cache_state)?; let mut outputs = ort_result(cache.session.run(model_inputs))?; debug!( suffix_tokens = suffix.len(), @@ -2782,19 +2865,18 @@ // Same fix as run_prompt: account for forced cache padding when // cache tensors will be included during prefill (#136). 
- let initial_cache_len: usize = - if cache.layout.use_cache.is_none() && !cache_state.is_empty() { - cache_state - .values() - .next() - .and_then(|v| { - let spec = cache.layout.cache_specs.first()?; - v.shape().get(spec.past_axis).copied().map(|d| d as usize) - }) - .unwrap_or(0) - } else { - 0 - }; + initial_cache_len = if cache.layout.use_cache.is_none() && !cache_state.is_empty() { + cache_state + .values() + .next() + .and_then(|v| { + let spec = cache.layout.cache_specs.first()?; + v.shape().get(spec.past_axis).copied().map(|d| d as usize) + }) + .unwrap_or(0) + } else { + 0 + }; let attention_len = context_ids.len() + initial_cache_len; debug!( @@ -2847,13 +2929,16 @@ fn run_prompt_on_session( for step in 0..config.max_new_tokens.max(1) { let step_started = Instant::now(); let (decode_with_cache, step_input_ids, attention_len) = if cache_enabled { - // attention_len = past KV-cache entries + 1 current decode token (#114). + // attention_len = past KV-cache entries (context_ids tracks total + // processed tokens; initial_cache_len accounts for prior cache + // padding). The model internally handles the current decode + // token, so we must NOT add +1 here (#147). ( true, vec![*context_ids .last() .ok_or_else(|| anyhow!("empty context ids"))?], - (context_ids.len() + 1).max(1), + (context_ids.len() + initial_cache_len).max(1), ) } else { let step_input_ids = context_ids.clone(); @@ -3047,6 +3132,10 @@ pub fn run_prompt(config: &ModelConfig, prompt: &str) -> Result { bail!("prompt encoding produced no token IDs"); } + // Track initial KV-cache padding length so the decode loop can account + // for it in the attention mask (#147). + let mut initial_cache_len: usize = 0; + if cache_enabled { // Batch prefill: ingest the entire prompt in a single forward pass. 
let prefill_started = Instant::now(); @@ -3057,7 +3146,7 @@ pub fn run_prompt(config: &ModelConfig, prompt: &str) -> Result { // attention will concatenate past (length 1) + current (length N), // producing a key sequence of N+1. The attention mask must match that // total, otherwise we get a broadcast shape error (#136). - let initial_cache_len: usize = if layout.use_cache.is_none() && !cache_state.is_empty() { + initial_cache_len = if layout.use_cache.is_none() && !cache_state.is_empty() { // Cache IS included during prefill — get its actual past-axis size. cache_state .values() @@ -3110,13 +3199,16 @@ pub fn run_prompt(config: &ModelConfig, prompt: &str) -> Result { for step in 0..config.max_new_tokens.max(1) { let step_started = Instant::now(); let (decode_with_cache, step_input_ids, attention_len) = if cache_enabled { - // attention_len = past KV-cache entries + 1 current decode token (#114). + // attention_len = past KV-cache entries (context_ids tracks total + // processed tokens; initial_cache_len accounts for prior cache + // padding). The model internally handles the current decode + // token, so we must NOT add +1 here (#147). ( true, vec![*context_ids .last() .ok_or_else(|| anyhow!("empty context ids"))?], - (context_ids.len() + 1).max(1), + (context_ids.len() + initial_cache_len).max(1), ) } else { let step_input_ids = context_ids.clone();