Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 58 additions & 4 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,12 +478,13 @@ use anyhow::{anyhow, bail, Context, Result};
use clap::{Parser, ValueEnum};
use core_engine::agent::Agent;
use core_engine::{
CoverageBaseline, EvidencePointer, Finding, FindingSeverity, LiveFailureReasonCount,
LiveFallbackDecision, LiveRunMetrics, RunReport,
classify_capability, CoverageBaseline, EvidencePointer, Finding, FindingSeverity,
LiveFailureReasonCount, LiveFallbackDecision, LiveRunMetrics, ModelCapabilityReport,
ModelCapabilityTier, RunReport,
};
use cyber_tools::{ToolRegistry, ToolSpec};
use inference_bridge::onnx_vitis::{inspect_runtime_compatibility, RuntimeCompatibilitySeverity};
use inference_bridge::{ModelConfig, OnnxVitisEngine, VitisEpConfig};
use inference_bridge::{probe_model_capability, ModelConfig, OnnxVitisEngine, VitisEpConfig};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
Expand Down Expand Up @@ -579,6 +580,24 @@ enum OutputMode {
Full,
}

// CLI-facing mirror of the engine's capability tier, selectable via
// `--capability-override`. Kept as a separate enum so it can derive
// `ValueEnum` (clap) and serde traits without imposing them on the core type.
// NOTE: plain `//` comments are used on purpose — `///` doc comments on
// `ValueEnum` variants would be picked up as CLI help text.
#[derive(Debug, Clone, Copy, ValueEnum, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum CapabilityOverride {
Basic,
Moderate,
Strong,
}

impl CapabilityOverride {
fn to_tier(self) -> ModelCapabilityTier {
match self {
Self::Basic => ModelCapabilityTier::Basic,
Self::Moderate => ModelCapabilityTier::Moderate,
Self::Strong => ModelCapabilityTier::Strong,
}
}
}

#[derive(Debug, Parser, Clone)]
#[command(name = "wraithrun", about = "Local-first cyber investigation runtime")]
struct Cli {
Expand Down Expand Up @@ -728,6 +747,9 @@ struct Cli {

#[arg(long)]
vitis_cache_key: Option<String>,

#[arg(long, value_enum)]
capability_override: Option<CapabilityOverride>,
}

#[derive(Debug, Clone, Default, Deserialize)]
Expand Down Expand Up @@ -788,6 +810,7 @@ struct RuntimeConfig {
vitis_config: Option<String>,
vitis_cache_dir: Option<String>,
vitis_cache_key: Option<String>,
capability_override: Option<CapabilityOverride>,
}

#[derive(Debug, Serialize)]
Expand Down Expand Up @@ -1104,6 +1127,7 @@ impl RuntimeConfig {
vitis_config: None,
vitis_cache_dir: None,
vitis_cache_key: None,
capability_override: None,
}
}

Expand Down Expand Up @@ -2437,6 +2461,9 @@ fn apply_cli_overrides(runtime: &mut RuntimeConfig, cli: &Cli) {
if let Some(vitis_cache_key) = &cli.vitis_cache_key {
runtime.vitis_cache_key = Some(vitis_cache_key.clone());
}
if let Some(capability_override) = cli.capability_override {
runtime.capability_override = Some(capability_override);
}
}

fn apply_fragment_with_source(
Expand Down Expand Up @@ -2628,6 +2655,9 @@ fn apply_cli_overrides_with_source(
runtime.vitis_cache_key = Some(vitis_cache_key.clone());
sources.vitis_cache_key = "cli --vitis-cache-key".to_string();
}
if let Some(capability_override) = cli.capability_override {
runtime.capability_override = Some(capability_override);
}
}

fn validate_runtime_config(config: &RuntimeConfig) -> Result<()> {
Expand Down Expand Up @@ -5355,9 +5385,31 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result<RunRep
vitis_config,
};

// Determine capability tier: override > probe > default.
let (tier, capability_report) = if let Some(cap_override) = runtime.capability_override {
let tier = cap_override.to_tier();
let probe = probe_model_capability(&model_config);
let mut report = ModelCapabilityReport::from_probe(&probe, tier);
report.r#override = true;
(tier, Some(report))
} else if dry_run {
(ModelCapabilityTier::Strong, None)
} else {
let probe = probe_model_capability(&model_config);
let tier = classify_capability(&probe);
let report = ModelCapabilityReport::from_probe(&probe, tier);
(tier, Some(report))
};

let brain = OnnxVitisEngine::new(model_config);
let tools = ToolRegistry::with_default_tools();
let mut agent = Agent::new(brain, tools).with_max_steps(runtime.max_steps);
let mut agent = Agent::new(brain, tools)
.with_max_steps(runtime.max_steps)
.with_capability_tier(tier);

if let Some(report) = capability_report {
agent = agent.with_model_capability_report(report);
}

if let Some(baseline_bundle) = runtime.baseline_bundle.as_deref() {
let coverage_baseline = load_coverage_baseline_from_bundle(baseline_bundle)?;
Expand Down Expand Up @@ -5565,6 +5617,7 @@ mod tests {
vitis_config: None,
vitis_cache_dir: None,
vitis_cache_key: None,
capability_override: None,
}
}

Expand All @@ -5573,6 +5626,7 @@ mod tests {
task: "Check suspicious listener ports and summarize risk".to_string(),
case_id: Some("CASE-2026-0001".to_string()),
max_severity: Some(FindingSeverity::Medium),
model_capability: None,
live_fallback_decision: None,
run_timing: None,
live_run_metrics: None,
Expand Down
78 changes: 62 additions & 16 deletions core_engine/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@ use cyber_tools::ToolRegistry;
use inference_bridge::InferenceEngine;

use crate::{
deduplicate_findings, derive_findings, extract_tag, max_severity, quality_checked_final_answer,
sort_findings, AgentTurn, CoverageBaseline, RunReport, RunTimingMetrics, ToolCall,
basic_tier_summary, deduplicate_findings, derive_findings, extract_tag, max_severity,
quality_checked_final_answer, sort_findings, AgentTurn, CoverageBaseline,
ModelCapabilityReport, ModelCapabilityTier, RunReport, RunTimingMetrics, ToolCall,
};

/// Investigation agent: drives tool calls through `tools`, then synthesizes
/// findings with the inference engine `brain`. Synthesis behavior is gated by
/// `capability_tier` (see `run`).
pub struct Agent<B: InferenceEngine> {
// Inference engine used for LLM synthesis.
brain: B,
// Registry of tools the agent may invoke during the evidence phase.
tools: ToolRegistry,
// Upper bound on tool-calling steps (default 8, see `new`).
max_steps: usize,
// Optional coverage baseline applied to outgoing tool calls.
coverage_baseline: Option<CoverageBaseline>,
// Controls synthesis strategy: Basic skips the LLM, Moderate limits
// evidence, Strong uses full evidence. Defaults to Strong in `new`.
capability_tier: ModelCapabilityTier,
// Optional probe/override report echoed into the final `RunReport`.
model_capability_report: Option<ModelCapabilityReport>,
}

impl<B: InferenceEngine> Agent<B> {
Expand All @@ -25,6 +28,8 @@ impl<B: InferenceEngine> Agent<B> {
tools,
max_steps: 8,
coverage_baseline: None,
capability_tier: ModelCapabilityTier::Strong,
model_capability_report: None,
}
}

Expand All @@ -42,6 +47,16 @@ impl<B: InferenceEngine> Agent<B> {
self
}

/// Builder-style setter: choose the capability tier that gates how `run`
/// performs synthesis (Basic = no LLM, Moderate = limited evidence,
/// Strong = full evidence).
pub fn with_capability_tier(mut self, tier: ModelCapabilityTier) -> Self {
self.capability_tier = tier;
self
}

/// Builder-style setter: attach a capability probe/override report, which is
/// copied into the `model_capability` field of the resulting `RunReport`.
pub fn with_model_capability_report(mut self, report: ModelCapabilityReport) -> Self {
self.model_capability_report = Some(report);
self
}

fn apply_coverage_baseline_to_call(&self, call: &mut ToolCall) {
let Some(coverage_baseline) = self.coverage_baseline.as_ref() else {
return;
Expand Down Expand Up @@ -125,29 +140,49 @@ impl<B: InferenceEngine> Agent<B> {
});
}

// Phase 2: LLM synthesis — analyze evidence and produce findings.
let evidence_summary = build_evidence_summary(&turns);
let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary);

let output = self.brain.generate(&synthesis_prompt).await?;
let first_token_latency_ms = Some(elapsed_ms_since(run_started_at));
info!(output = %output, "agent synthesis output");

let raw_final_answer = extract_tag(&output, "final").unwrap_or(output);

// Structured findings from tool observations (rule-based extraction).
let raw_findings = derive_findings(&turns, &raw_final_answer);
// Phase 2: synthesis — behavior depends on capability tier.
let raw_findings = derive_findings(&turns, "");
let mut findings = deduplicate_findings(raw_findings);
sort_findings(&mut findings);

// Quality-check LLM output; replace with deterministic summary if low quality.
let final_answer = quality_checked_final_answer(&raw_final_answer, &findings);
let (final_answer, first_token_latency_ms) = match self.capability_tier {
ModelCapabilityTier::Basic => {
// Skip LLM entirely; build deterministic summary from findings.
debug!("Basic tier: skipping LLM synthesis");
let answer = basic_tier_summary(&findings);
(answer, None)
}
ModelCapabilityTier::Moderate => {
// Call LLM with reduced evidence (top-5 observations).
let evidence_summary = build_evidence_summary_limited(&turns, 5);
let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary);
let output = self.brain.generate(&synthesis_prompt).await?;
let latency = Some(elapsed_ms_since(run_started_at));
info!(output = %output, "agent synthesis output (moderate)");
let raw = extract_tag(&output, "final").unwrap_or(output);
let answer = quality_checked_final_answer(&raw, &findings);
(answer, latency)
}
ModelCapabilityTier::Strong => {
// Full evidence, full synthesis.
let evidence_summary = build_evidence_summary(&turns);
let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary);
let output = self.brain.generate(&synthesis_prompt).await?;
let latency = Some(elapsed_ms_since(run_started_at));
info!(output = %output, "agent synthesis output (strong)");
let raw = extract_tag(&output, "final").unwrap_or(output);
let answer = quality_checked_final_answer(&raw, &findings);
(answer, latency)
}
};

let report_max_severity = max_severity(&findings);

Ok(RunReport {
task: task.to_string(),
case_id: None,
max_severity: report_max_severity,
model_capability: self.model_capability_report.clone(),
live_fallback_decision: None,
run_timing: Some(build_run_timing_metrics(
run_started_at,
Expand Down Expand Up @@ -249,8 +284,18 @@ fn has_word(text: &str, word: &str) -> bool {

/// Build a concise evidence summary from tool observations for LLM synthesis.
///
/// Delegates to [`build_evidence_summary_limited`] with an effectively
/// unlimited cap (`usize::MAX`), so every turn's observation is included.
/// Used by the Strong tier; the Moderate tier calls the limited variant
/// directly with a small cap.
fn build_evidence_summary(turns: &[AgentTurn]) -> String {
build_evidence_summary_limited(turns, usize::MAX)
}

/// Build an evidence summary limited to the first `max_turns` observations.
/// Used by Moderate tier to reduce prompt size.
fn build_evidence_summary_limited(turns: &[AgentTurn], max_turns: usize) -> String {
let mut summary = String::new();
let mut count = 0;
for turn in turns {
if count >= max_turns {
break;
}
let tool_name = turn
.tool_call
.as_ref()
Expand All @@ -265,6 +310,7 @@ fn build_evidence_summary(turns: &[AgentTurn]) -> String {
obs_str
};
summary.push_str(&format!("[{tool_name}] {truncated}\n\n"));
count += 1;
}
}
summary
Expand Down
Loading
Loading