Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 184 additions & 2 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,9 @@ use core_engine::{
};
use cyber_tools::{ToolRegistry, ToolSpec};
use inference_bridge::onnx_vitis::{inspect_runtime_compatibility, RuntimeCompatibilitySeverity};
use inference_bridge::{probe_model_capability, ModelConfig, OnnxVitisEngine, VitisEpConfig};
use inference_bridge::{
backend::ProviderRegistry, probe_model_capability, ModelConfig, OnnxVitisEngine, VitisEpConfig,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sha2::{Digest, Sha256};
Expand Down Expand Up @@ -774,6 +776,11 @@ struct Cli {
#[arg(long)]
vitis_cache_key: Option<String>,

/// Override inference backend selection (e.g. "cpu", "vitis").
/// Use "auto" or omit to auto-select the highest-priority available backend.
#[arg(long, value_name = "NAME")]
backend: Option<String>,

#[arg(long, value_enum)]
capability_override: Option<CapabilityOverride>,
}
Expand Down Expand Up @@ -802,6 +809,13 @@ struct SettingsFragment {
vitis_config: Option<String>,
vitis_cache_dir: Option<String>,
vitis_cache_key: Option<String>,
backend: Option<String>,
}

#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default, deny_unknown_fields)]
struct InferenceConfig {
backend: Option<String>,
}

#[derive(Debug, Clone, Default, Deserialize)]
Expand All @@ -810,6 +824,8 @@ struct FileConfig {
#[serde(flatten)]
defaults: SettingsFragment,
profiles: HashMap<String, SettingsFragment>,
#[serde(default)]
inference: InferenceConfig,
}

#[derive(Debug, Clone)]
Expand All @@ -836,6 +852,7 @@ struct RuntimeConfig {
vitis_config: Option<String>,
vitis_cache_dir: Option<String>,
vitis_cache_key: Option<String>,
backend: Option<String>,
capability_override: Option<CapabilityOverride>,
tools_dir: Option<PathBuf>,
allowed_plugins: Vec<String>,
Expand Down Expand Up @@ -865,6 +882,7 @@ struct RuntimeConfigView {
vitis_config: Option<String>,
vitis_cache_dir: Option<String>,
vitis_cache_key: Option<String>,
backend: Option<String>,
}

#[derive(Debug, Clone, Serialize)]
Expand All @@ -891,6 +909,7 @@ struct RuntimeConfigSources {
vitis_config: String,
vitis_cache_dir: String,
vitis_cache_key: String,
backend: String,
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -1088,6 +1107,8 @@ struct DoctorSummaryView {
struct DoctorReportView<'a> {
summary: DoctorSummaryView,
checks: &'a [DoctorCheck],
#[serde(skip_serializing_if = "<[DoctorBackendEntry]>::is_empty")]
backends: &'a [DoctorBackendEntry],
}

#[derive(Debug, Default, Serialize)]
Expand Down Expand Up @@ -1155,6 +1176,7 @@ impl RuntimeConfig {
vitis_config: None,
vitis_cache_dir: None,
vitis_cache_key: None,
backend: None,
capability_override: None,
tools_dir: None,
allowed_plugins: Vec::new(),
Expand Down Expand Up @@ -1222,6 +1244,9 @@ impl RuntimeConfig {
if let Some(vitis_cache_key) = &fragment.vitis_cache_key {
self.vitis_cache_key = Some(vitis_cache_key.clone());
}
if let Some(backend) = &fragment.backend {
self.backend = Some(backend.clone());
}
}
}

Expand Down Expand Up @@ -1250,6 +1275,7 @@ impl RuntimeConfigSources {
vitis_config: "default".to_string(),
vitis_cache_dir: "default".to_string(),
vitis_cache_key: "default".to_string(),
backend: "default".to_string(),
}
}
}
Expand Down Expand Up @@ -1338,6 +1364,23 @@ struct DoctorCheck {
#[derive(Debug, Default, Serialize)]
struct DoctorReport {
checks: Vec<DoctorCheck>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
backends: Vec<DoctorBackendEntry>,
}

#[derive(Debug, Clone, Serialize)]
struct DoctorBackendEntry {
name: String,
priority: u32,
available: bool,
diagnostics: Vec<DoctorBackendDiag>,
}

#[derive(Debug, Clone, Serialize)]
struct DoctorBackendDiag {
status: DoctorStatus,
check: String,
detail: String,
}

impl DoctorReport {
Expand Down Expand Up @@ -2215,6 +2258,13 @@ fn merge_sources(
if let Some(file_config) = file_config {
resolved.apply_fragment(&file_config.defaults);

// Apply [inference] section if present.
if let Some(backend) = &file_config.inference.backend {
if backend != "auto" {
resolved.backend = Some(backend.clone());
}
}

if let Some(profile_name) = profile.as_deref() {
if let Some(profile_settings) = lookup_profile(&file_config.profiles, profile_name) {
resolved.apply_fragment(profile_settings);
Expand Down Expand Up @@ -2273,6 +2323,16 @@ fn merge_sources_with_explanation(
.unwrap_or_else(|| "config defaults".to_string());
apply_fragment_with_source(&mut resolved, &mut sources, &file_config.defaults, &source);

// Apply [inference] section if present.
if let Some(backend) = &file_config.inference.backend {
if backend != "auto" {
resolved.backend = Some(backend.clone());
sources.backend = file_config_path
.map(|path| format!("config [inference] ({})", path.display()))
.unwrap_or_else(|| "config [inference]".to_string());
}
}

if let Some(profile_name) = profile.as_deref() {
if let Some(profile_settings) = lookup_profile(&file_config.profiles, profile_name) {
let source = file_config_path
Expand Down Expand Up @@ -2477,6 +2537,7 @@ fn env_settings_fragment() -> Result<SettingsFragment> {
vitis_config: read_env_string("WRAITHRUN_VITIS_CONFIG")?,
vitis_cache_dir: read_env_string("WRAITHRUN_VITIS_CACHE_DIR")?,
vitis_cache_key: read_env_string("WRAITHRUN_VITIS_CACHE_KEY")?,
backend: read_env_string("WRAITHRUN_BACKEND")?,
})
}

Expand Down Expand Up @@ -2550,6 +2611,9 @@ fn apply_cli_overrides(runtime: &mut RuntimeConfig, cli: &Cli) {
if let Some(vitis_cache_key) = &cli.vitis_cache_key {
runtime.vitis_cache_key = Some(vitis_cache_key.clone());
}
if let Some(backend) = &cli.backend {
runtime.backend = Some(backend.clone());
}
if let Some(capability_override) = cli.capability_override {
runtime.capability_override = Some(capability_override);
}
Expand Down Expand Up @@ -2651,6 +2715,10 @@ fn apply_fragment_with_source(
runtime.vitis_cache_key = Some(vitis_cache_key.clone());
sources.vitis_cache_key = source.to_string();
}
if let Some(backend) = &fragment.backend {
runtime.backend = Some(backend.clone());
sources.backend = source.to_string();
}
}

fn apply_cli_overrides_with_source(
Expand Down Expand Up @@ -2750,6 +2818,10 @@ fn apply_cli_overrides_with_source(
runtime.vitis_cache_key = Some(vitis_cache_key.clone());
sources.vitis_cache_key = "cli --vitis-cache-key".to_string();
}
if let Some(backend) = &cli.backend {
runtime.backend = Some(backend.clone());
sources.backend = "cli --backend".to_string();
}
if let Some(capability_override) = cli.capability_override {
runtime.capability_override = Some(capability_override);
}
Expand Down Expand Up @@ -3733,6 +3805,7 @@ impl RuntimeConfigView {
vitis_config: runtime.vitis_config.clone(),
vitis_cache_dir: runtime.vitis_cache_dir.clone(),
vitis_cache_key: runtime.vitis_cache_key.clone(),
backend: runtime.backend.clone(),
}
}
}
Expand Down Expand Up @@ -3925,6 +3998,59 @@ fn run_doctor(cli: &Cli) -> DoctorReport {
}
}

// Inference backend diagnostics.
let registry = ProviderRegistry::discover();
for bd in registry.diagnose_all() {
let status = if bd.info.available {
DoctorStatus::Pass
} else {
DoctorStatus::Warn
};
let availability = if bd.info.available {
"available"
} else {
"not available"
};
report.push(
status,
"inference-backend",
format!(
"{} (priority {}) — {}",
bd.info.name, bd.info.priority, availability
),
);
let mut backend_diags = Vec::new();
for diag in &bd.diagnostics {
let diag_status = match diag.severity {
inference_bridge::backend::DiagnosticSeverity::Pass => {
DoctorStatus::Pass
}
inference_bridge::backend::DiagnosticSeverity::Warn => {
DoctorStatus::Warn
}
inference_bridge::backend::DiagnosticSeverity::Fail => {
DoctorStatus::Fail
}
};
report.push(
diag_status,
"inference-backend",
format!(" {} — {}: {}", bd.info.name, diag.check, diag.message),
);
backend_diags.push(DoctorBackendDiag {
status: diag_status,
check: diag.check.clone(),
detail: diag.message.clone(),
});
}
report.backends.push(DoctorBackendEntry {
name: bd.info.name,
priority: bd.info.priority,
available: bd.info.available,
diagnostics: backend_diags,
});
}

// Plugin tools check.
if !runtime.allowed_plugins.is_empty() {
let tools_dir = runtime
Expand Down Expand Up @@ -4492,6 +4618,7 @@ fn render_doctor_report_json(report: &DoctorReport) -> Result<String> {
fail: fail_count,
},
checks: &report.checks,
backends: &report.backends,
};
render_json_with_contract(&view)
}
Expand Down Expand Up @@ -5744,9 +5871,14 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result<RunRep
validate_live_runtime_preflight(runtime)?;
}

// Resolve inference backend.
let registry = ProviderRegistry::discover();
let resolved_backend_name = resolve_backend(&registry, runtime)?;

let vitis_config = build_vitis_config(runtime);
let (backend_override, backend_config) = match vitis_config {
Some(cfg) => (Some("vitis".to_string()), cfg.into_backend_config()),
None if runtime.backend.is_some() => (runtime.backend.clone(), Default::default()),
None => (None, Default::default()),
};
let model_config = ModelConfig {
Expand All @@ -5759,6 +5891,8 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result<RunRep
backend_config,
};

tracing::info!(backend = %resolved_backend_name, "Selected inference backend");

// Determine capability tier: override > probe > default.
let (tier, capability_report) = if let Some(cap_override) = runtime.capability_override {
let tier = cap_override.to_tier();
Expand Down Expand Up @@ -5800,7 +5934,50 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result<RunRep
agent = agent.with_coverage_baseline(coverage_baseline);
}

agent.run(&runtime.task).await
let mut report = agent.run(&runtime.task).await?;
report.backend = Some(resolved_backend_name);
Ok(report)
}

/// Resolve which backend to use based on `--backend` flag or auto-select.
fn resolve_backend(registry: &ProviderRegistry, runtime: &RuntimeConfig) -> Result<String> {
if let Some(requested) = &runtime.backend {
if requested.eq_ignore_ascii_case("auto") {
// Explicit "auto" — fall through to auto-select.
} else {
// User asked for a specific backend.
match registry.get(requested) {
Some(backend) => {
if backend.is_available() {
return Ok(backend.name().to_string());
}
bail!(
"{} backend is registered but not available on this system",
backend.name()
);
}
None => {
let available = registry
.list()
.iter()
.map(|p| p.name.clone())
.collect::<Vec<_>>()
.join(", ");
bail!(
"\"{}\" backend not found; available backends: {}",
requested,
available
);
}
}
}
}

// Auto-select: pick highest-priority available backend.
match registry.best_available() {
Some(backend) => Ok(backend.name().to_string()),
None => bail!("no inference backend available"),
}
}

fn append_live_fallback_finding(report: &mut RunReport, decision: &LiveFallbackDecision) {
Expand Down Expand Up @@ -6006,6 +6183,7 @@ mod tests {
vitis_config: None,
vitis_cache_dir: None,
vitis_cache_key: None,
backend: None,
capability_override: None,
tools_dir: None,
allowed_plugins: vec![],
Expand All @@ -6017,6 +6195,7 @@ mod tests {
task: "Check suspicious listener ports and summarize risk".to_string(),
case_id: Some("CASE-2026-0001".to_string()),
max_severity: Some(FindingSeverity::Medium),
backend: None,
model_capability: None,
live_fallback_decision: None,
run_timing: None,
Expand Down Expand Up @@ -6673,6 +6852,7 @@ mod tests {
..SettingsFragment::default()
},
)]),
..Default::default()
};

let env_overrides = SettingsFragment {
Expand Down Expand Up @@ -6911,6 +7091,7 @@ mod tests {
("team-default".to_string(), SettingsFragment::default()),
("incident-hotfix".to_string(), SettingsFragment::default()),
]),
..Default::default()
};

let rendered = render_profile_list(
Expand All @@ -6930,6 +7111,7 @@ mod tests {
let file_config = FileConfig {
defaults: SettingsFragment::default(),
profiles: HashMap::from([("incident-hotfix".to_string(), SettingsFragment::default())]),
..Default::default()
};

let rendered = render_profile_list_json(
Expand Down
Loading
Loading