From cef18980ce8a871b0f4218a7884aacb01f9ade74 Mon Sep 17 00:00:00 2001 From: Shreyas Sankpal Date: Fri, 3 Apr 2026 16:58:35 -0400 Subject: [PATCH] feat: v0.12.0 Model Capability Tiering (#75-#81) - Add ModelCapabilityProbe struct and probe_model_capability() in inference_bridge (#76): estimates params from file size, detects EP, estimates smoke latency, extracts vocab size from tokenizer - Add ModelCapabilityTier enum (Basic/Moderate/Strong) with classify_capability() in core_engine (#77): const thresholds, tier = min(param_tier, latency_tier) - Agent adapts Phase 2 by tier (#78): Basic skips LLM entirely, Moderate uses reduced evidence, Strong uses full synthesis - Basic tier deterministic summary (#79): structured SUMMARY/FINDINGS/ RISK/ACTIONS format, byte-identical across runs - ModelCapabilityReport in RunReport JSON output (#80): tier, params, EP, latency, vocab_size; absent in dry-run mode - --capability-override CLI flag (#81): forces tier, skips classification (probe still runs to fill report fields), adds override:true in output - 18 new unit tests for classification boundaries, tier ordering, serialization, basic tier summary, and capability report Closes #75, closes #76, closes #77, closes #78, closes #79, closes #80, closes #81 --- cli/src/main.rs | 62 +++- core_engine/src/agent.rs | 78 ++++- core_engine/src/lib.rs | 328 ++++++++++++++++++ docs/schemas/examples/run-report.example.json | 8 + docs/schemas/run-report.schema.json | 33 ++ inference_bridge/src/lib.rs | 114 ++++++ 6 files changed, 603 insertions(+), 20 deletions(-) diff --git a/cli/src/main.rs b/cli/src/main.rs index 9fc9063..98ba383 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -478,12 +478,13 @@ use anyhow::{anyhow, bail, Context, Result}; use clap::{Parser, ValueEnum}; use core_engine::agent::Agent; use core_engine::{ - CoverageBaseline, EvidencePointer, Finding, FindingSeverity, LiveFailureReasonCount, - LiveFallbackDecision, LiveRunMetrics, RunReport, + classify_capability, CoverageBaseline, EvidencePointer, Finding, 
FindingSeverity, + LiveFailureReasonCount, LiveFallbackDecision, LiveRunMetrics, ModelCapabilityReport, + ModelCapabilityTier, RunReport, }; use cyber_tools::{ToolRegistry, ToolSpec}; use inference_bridge::onnx_vitis::{inspect_runtime_compatibility, RuntimeCompatibilitySeverity}; -use inference_bridge::{ModelConfig, OnnxVitisEngine, VitisEpConfig}; +use inference_bridge::{probe_model_capability, ModelConfig, OnnxVitisEngine, VitisEpConfig}; use serde::{Deserialize, Serialize}; use serde_json::Value; use sha2::{Digest, Sha256}; @@ -579,6 +580,24 @@ enum OutputMode { Full, } +#[derive(Debug, Clone, Copy, ValueEnum, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +enum CapabilityOverride { + Basic, + Moderate, + Strong, +} + +impl CapabilityOverride { + fn to_tier(self) -> ModelCapabilityTier { + match self { + Self::Basic => ModelCapabilityTier::Basic, + Self::Moderate => ModelCapabilityTier::Moderate, + Self::Strong => ModelCapabilityTier::Strong, + } + } +} + #[derive(Debug, Parser, Clone)] #[command(name = "wraithrun", about = "Local-first cyber investigation runtime")] struct Cli { @@ -728,6 +747,9 @@ struct Cli { #[arg(long)] vitis_cache_key: Option, + + #[arg(long, value_enum)] + capability_override: Option, } #[derive(Debug, Clone, Default, Deserialize)] @@ -788,6 +810,7 @@ struct RuntimeConfig { vitis_config: Option, vitis_cache_dir: Option, vitis_cache_key: Option, + capability_override: Option, } #[derive(Debug, Serialize)] @@ -1104,6 +1127,7 @@ impl RuntimeConfig { vitis_config: None, vitis_cache_dir: None, vitis_cache_key: None, + capability_override: None, } } @@ -2437,6 +2461,9 @@ fn apply_cli_overrides(runtime: &mut RuntimeConfig, cli: &Cli) { if let Some(vitis_cache_key) = &cli.vitis_cache_key { runtime.vitis_cache_key = Some(vitis_cache_key.clone()); } + if let Some(capability_override) = cli.capability_override { + runtime.capability_override = Some(capability_override); + } } fn apply_fragment_with_source( @@ -2628,6 
+2655,9 @@ fn apply_cli_overrides_with_source( runtime.vitis_cache_key = Some(vitis_cache_key.clone()); sources.vitis_cache_key = "cli --vitis-cache-key".to_string(); } + if let Some(capability_override) = cli.capability_override { + runtime.capability_override = Some(capability_override); + } } fn validate_runtime_config(config: &RuntimeConfig) -> Result<()> { @@ -5355,9 +5385,31 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result probe > default. + let (tier, capability_report) = if let Some(cap_override) = runtime.capability_override { + let tier = cap_override.to_tier(); + let probe = probe_model_capability(&model_config); + let mut report = ModelCapabilityReport::from_probe(&probe, tier); + report.r#override = true; + (tier, Some(report)) + } else if dry_run { + (ModelCapabilityTier::Strong, None) + } else { + let probe = probe_model_capability(&model_config); + let tier = classify_capability(&probe); + let report = ModelCapabilityReport::from_probe(&probe, tier); + (tier, Some(report)) + }; + let brain = OnnxVitisEngine::new(model_config); let tools = ToolRegistry::with_default_tools(); - let mut agent = Agent::new(brain, tools).with_max_steps(runtime.max_steps); + let mut agent = Agent::new(brain, tools) + .with_max_steps(runtime.max_steps) + .with_capability_tier(tier); + + if let Some(report) = capability_report { + agent = agent.with_model_capability_report(report); + } if let Some(baseline_bundle) = runtime.baseline_bundle.as_deref() { let coverage_baseline = load_coverage_baseline_from_bundle(baseline_bundle)?; @@ -5565,6 +5617,7 @@ mod tests { vitis_config: None, vitis_cache_dir: None, vitis_cache_key: None, + capability_override: None, } } @@ -5573,6 +5626,7 @@ mod tests { task: "Check suspicious listener ports and summarize risk".to_string(), case_id: Some("CASE-2026-0001".to_string()), max_severity: Some(FindingSeverity::Medium), + model_capability: None, live_fallback_decision: None, run_timing: None, live_run_metrics: 
None, diff --git a/core_engine/src/agent.rs b/core_engine/src/agent.rs index 3bc8757..a38886d 100644 --- a/core_engine/src/agent.rs +++ b/core_engine/src/agent.rs @@ -7,8 +7,9 @@ use cyber_tools::ToolRegistry; use inference_bridge::InferenceEngine; use crate::{ - deduplicate_findings, derive_findings, extract_tag, max_severity, quality_checked_final_answer, - sort_findings, AgentTurn, CoverageBaseline, RunReport, RunTimingMetrics, ToolCall, + basic_tier_summary, deduplicate_findings, derive_findings, extract_tag, max_severity, + quality_checked_final_answer, sort_findings, AgentTurn, CoverageBaseline, + ModelCapabilityReport, ModelCapabilityTier, RunReport, RunTimingMetrics, ToolCall, }; pub struct Agent { @@ -16,6 +17,8 @@ pub struct Agent { tools: ToolRegistry, max_steps: usize, coverage_baseline: Option, + capability_tier: ModelCapabilityTier, + model_capability_report: Option, } impl Agent { @@ -25,6 +28,8 @@ impl Agent { tools, max_steps: 8, coverage_baseline: None, + capability_tier: ModelCapabilityTier::Strong, + model_capability_report: None, } } @@ -42,6 +47,16 @@ impl Agent { self } + pub fn with_capability_tier(mut self, tier: ModelCapabilityTier) -> Self { + self.capability_tier = tier; + self + } + + pub fn with_model_capability_report(mut self, report: ModelCapabilityReport) -> Self { + self.model_capability_report = Some(report); + self + } + fn apply_coverage_baseline_to_call(&self, call: &mut ToolCall) { let Some(coverage_baseline) = self.coverage_baseline.as_ref() else { return; @@ -125,29 +140,49 @@ impl Agent { }); } - // Phase 2: LLM synthesis — analyze evidence and produce findings. 
- let evidence_summary = build_evidence_summary(&turns); - let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary); - - let output = self.brain.generate(&synthesis_prompt).await?; - let first_token_latency_ms = Some(elapsed_ms_since(run_started_at)); - info!(output = %output, "agent synthesis output"); - - let raw_final_answer = extract_tag(&output, "final").unwrap_or(output); - - // Structured findings from tool observations (rule-based extraction). - let raw_findings = derive_findings(&turns, &raw_final_answer); + // Phase 2: synthesis — behavior depends on capability tier. + let raw_findings = derive_findings(&turns, ""); let mut findings = deduplicate_findings(raw_findings); sort_findings(&mut findings); - // Quality-check LLM output; replace with deterministic summary if low quality. - let final_answer = quality_checked_final_answer(&raw_final_answer, &findings); + let (final_answer, first_token_latency_ms) = match self.capability_tier { + ModelCapabilityTier::Basic => { + // Skip LLM entirely; build deterministic summary from findings. + debug!("Basic tier: skipping LLM synthesis"); + let answer = basic_tier_summary(&findings); + (answer, None) + } + ModelCapabilityTier::Moderate => { + // Call LLM with reduced evidence (top-5 observations). + let evidence_summary = build_evidence_summary_limited(&turns, 5); + let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary); + let output = self.brain.generate(&synthesis_prompt).await?; + let latency = Some(elapsed_ms_since(run_started_at)); + info!(output = %output, "agent synthesis output (moderate)"); + let raw = extract_tag(&output, "final").unwrap_or(output); + let answer = quality_checked_final_answer(&raw, &findings); + (answer, latency) + } + ModelCapabilityTier::Strong => { + // Full evidence, full synthesis. 
+ let evidence_summary = build_evidence_summary(&turns); + let synthesis_prompt = format_synthesis_prompt(task, &evidence_summary); + let output = self.brain.generate(&synthesis_prompt).await?; + let latency = Some(elapsed_ms_since(run_started_at)); + info!(output = %output, "agent synthesis output (strong)"); + let raw = extract_tag(&output, "final").unwrap_or(output); + let answer = quality_checked_final_answer(&raw, &findings); + (answer, latency) + } + }; + let report_max_severity = max_severity(&findings); Ok(RunReport { task: task.to_string(), case_id: None, max_severity: report_max_severity, + model_capability: self.model_capability_report.clone(), live_fallback_decision: None, run_timing: Some(build_run_timing_metrics( run_started_at, @@ -249,8 +284,18 @@ fn has_word(text: &str, word: &str) -> bool { /// Build a concise evidence summary from tool observations for LLM synthesis. fn build_evidence_summary(turns: &[AgentTurn]) -> String { + build_evidence_summary_limited(turns, usize::MAX) +} + +/// Build an evidence summary limited to the first `max_turns` observations. +/// Used by Moderate tier to reduce prompt size. 
+fn build_evidence_summary_limited(turns: &[AgentTurn], max_turns: usize) -> String { let mut summary = String::new(); + let mut count = 0; for turn in turns { + if count >= max_turns { + break; + } let tool_name = turn .tool_call .as_ref() @@ -265,6 +310,7 @@ fn build_evidence_summary(turns: &[AgentTurn]) -> String { obs_str }; summary.push_str(&format!("[{tool_name}] {truncated}\n\n")); + count += 1; } } summary diff --git a/core_engine/src/lib.rs b/core_engine/src/lib.rs index 6beabd0..add0ac8 100644 --- a/core_engine/src/lib.rs +++ b/core_engine/src/lib.rs @@ -2,9 +2,72 @@ pub mod agent; use std::collections::HashSet; +use inference_bridge::ModelCapabilityProbe; use serde::{Deserialize, Serialize, Serializer}; use serde_json::Value; +// ── Capability tiering thresholds (const, easy to tune) ── + +/// Models below this parameter count (billions) are classified as Basic. +const PARAM_BASIC_CEILING_B: f32 = 2.0; +/// Models above this parameter count (billions) are classified as Strong. +const PARAM_STRONG_FLOOR_B: f32 = 10.0; +/// Latency above this (ms/tok) demotes to Basic. +const LATENCY_BASIC_FLOOR_MS: u64 = 200; +/// Latency below this (ms/tok) promotes to Strong. +const LATENCY_STRONG_CEILING_MS: u64 = 50; + +/// Model capability tier that determines agent behavior in Phase 2. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] +#[serde(rename_all = "lowercase")] +pub enum ModelCapabilityTier { + Basic, + Moderate, + Strong, +} + +impl ModelCapabilityTier { + pub fn token(self) -> &'static str { + match self { + Self::Basic => "basic", + Self::Moderate => "moderate", + Self::Strong => "strong", + } + } +} + +impl std::fmt::Display for ModelCapabilityTier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.token()) + } +} + +/// Classify a model's capability probe signals into a tier. +/// +/// Final tier = min(param_tier, latency_tier). 
+/// A 13B model with moderate latency is Moderate (latency-constrained). +/// A 1B model on a fast GPU is Basic (param-constrained). +pub fn classify_capability(probe: &ModelCapabilityProbe) -> ModelCapabilityTier { + let param_tier = if probe.estimated_param_billions < PARAM_BASIC_CEILING_B { + ModelCapabilityTier::Basic + } else if probe.estimated_param_billions > PARAM_STRONG_FLOOR_B { + ModelCapabilityTier::Strong + } else { + ModelCapabilityTier::Moderate + }; + + let latency_tier = if probe.smoke_latency_ms > LATENCY_BASIC_FLOOR_MS { + ModelCapabilityTier::Basic + } else if probe.smoke_latency_ms < LATENCY_STRONG_CEILING_MS { + ModelCapabilityTier::Strong + } else { + ModelCapabilityTier::Moderate + }; + + // min() works because of the Ord derive: Basic < Moderate < Strong. + param_tier.min(latency_tier) +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolCall { pub tool: String, @@ -132,6 +195,8 @@ pub struct RunReport { #[serde(default, skip_serializing_if = "Option::is_none")] pub max_severity: Option, #[serde(default, skip_serializing_if = "Option::is_none")] + pub model_capability: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub live_fallback_decision: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub run_timing: Option, @@ -713,6 +778,76 @@ fn deterministic_summary(findings: &[Finding]) -> String { ) } + +/// Build a rich deterministic summary for Basic-tier runs (no LLM). +/// +/// Format follows the spec from issue #79 — byte-identical across runs +/// with the same findings. +pub fn basic_tier_summary(findings: &[Finding]) -> String { + if findings.is_empty() { + return "SUMMARY: 0 findings detected. Maximum severity: info.\nFINDINGS:\n(none)\nRISK: info\nACTIONS:\n(none)".to_string(); + } + + let max_sev = findings + .iter() + .map(|f| f.severity) + .max() + .unwrap_or(FindingSeverity::Info); + + let mut out = format!( + "SUMMARY: {} findings detected. 
Maximum severity: {}.\nFINDINGS:\n", + findings.len(), + max_sev.token() + ); + + for (i, f) in findings.iter().enumerate() { + out.push_str(&format!( + "{}. {} [{}] — {}\n", + i + 1, + f.title, + f.severity.token(), + f.recommended_action + )); + } + + out.push_str(&format!("RISK: {}\nACTIONS:\n", max_sev.token())); + + for (i, f) in findings.iter().enumerate() { + out.push_str(&format!("{}. {}\n", i + 1, f.recommended_action)); + } + + // Remove trailing newline for clean output. + if out.ends_with('\n') { + out.truncate(out.len() - 1); + } + + out +} + +/// Model capability report for JSON output (#80). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelCapabilityReport { + pub tier: ModelCapabilityTier, + pub estimated_params_b: f32, + pub execution_provider: String, + pub smoke_latency_ms: u64, + pub vocab_size: usize, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub r#override: bool, +} + +impl ModelCapabilityReport { + pub fn from_probe(probe: &ModelCapabilityProbe, tier: ModelCapabilityTier) -> Self { + Self { + tier, + estimated_params_b: probe.estimated_param_billions, + execution_provider: probe.execution_provider.clone(), + smoke_latency_ms: probe.smoke_latency_ms, + vocab_size: probe.vocab_size, + r#override: false, + } + } +} + fn confidence_from_count(base: f32, count: u64, slope: f32, ceiling: f32) -> f32 { let raw = (base + (count as f32 * slope)).min(ceiling); (raw * 100.0).round() / 100.0 @@ -1097,4 +1232,197 @@ mod tests { assert!(FindingSeverity::Medium > FindingSeverity::Low); assert!(FindingSeverity::Low > FindingSeverity::Info); } + + // ── Capability tiering tests (#77) ── + + use super::{ + basic_tier_summary, classify_capability, ModelCapabilityReport, ModelCapabilityTier, + }; + use inference_bridge::ModelCapabilityProbe; + + #[test] + fn classify_small_model_as_basic() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 1.2, + execution_provider: "CPUExecutionProvider".to_string(), + 
smoke_latency_ms: 80, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Basic); + } + + #[test] + fn classify_medium_model_moderate_latency_as_moderate() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 7.0, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 120, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Moderate); + } + + #[test] + fn classify_large_model_fast_gpu_as_strong() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 13.0, + execution_provider: "CUDAExecutionProvider".to_string(), + smoke_latency_ms: 30, + vocab_size: 128256, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Strong); + } + + #[test] + fn classify_large_model_slow_cpu_as_basic() { + // 13B model but 250ms/tok latency → latency-constrained → Basic. + let probe = ModelCapabilityProbe { + estimated_param_billions: 13.0, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 250, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Basic); + } + + #[test] + fn classify_small_model_fast_gpu_as_basic() { + // 1B model on fast GPU → param-constrained → Basic. 
+ let probe = ModelCapabilityProbe { + estimated_param_billions: 0.8, + execution_provider: "CUDAExecutionProvider".to_string(), + smoke_latency_ms: 10, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Basic); + } + + #[test] + fn classify_boundary_2b_model_as_moderate() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 2.0, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 100, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Moderate); + } + + #[test] + fn classify_boundary_latency_200ms_as_moderate() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 5.0, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 200, + vocab_size: 32000, + }; + assert_eq!(classify_capability(&probe), ModelCapabilityTier::Moderate); + } + + #[test] + fn tier_ordering() { + assert!(ModelCapabilityTier::Basic < ModelCapabilityTier::Moderate); + assert!(ModelCapabilityTier::Moderate < ModelCapabilityTier::Strong); + } + + #[test] + fn tier_serializes_lowercase() { + let json = serde_json::to_string(&ModelCapabilityTier::Basic).unwrap(); + assert_eq!(json, "\"basic\""); + let json = serde_json::to_string(&ModelCapabilityTier::Strong).unwrap(); + assert_eq!(json, "\"strong\""); + } + + #[test] + fn tier_deserializes_from_lowercase() { + let tier: ModelCapabilityTier = serde_json::from_str("\"moderate\"").unwrap(); + assert_eq!(tier, ModelCapabilityTier::Moderate); + } + + // ── Basic tier summary tests (#79) ── + + #[test] + fn basic_tier_summary_empty_findings() { + let summary = basic_tier_summary(&[]); + assert!(summary.starts_with("SUMMARY: 0 findings detected.")); + assert!(summary.contains("RISK: info")); + } + + #[test] + fn basic_tier_summary_with_findings() { + let findings = vec![ + make_finding( + "Active listeners", + FindingSeverity::High, + 0.80, + "scan_network", + "observation.listener_count", + ), + 
make_finding( + "Suspicious persistence", + FindingSeverity::Medium, + 0.70, + "inspect_persistence_locations", + "observation.suspicious_entry_count", + ), + ]; + + let summary = basic_tier_summary(&findings); + assert!(summary.starts_with("SUMMARY: 2 findings detected. Maximum severity: high.")); + assert!(summary.contains("FINDINGS:")); + assert!(summary.contains("1. Active listeners [high]")); + assert!(summary.contains("2. Suspicious persistence [medium]")); + assert!(summary.contains("RISK: high")); + assert!(summary.contains("ACTIONS:")); + } + + #[test] + fn basic_tier_summary_is_deterministic() { + let findings = vec![make_finding( + "Test", + FindingSeverity::Low, + 0.50, + "scan_network", + "a", + )]; + let a = basic_tier_summary(&findings); + let b = basic_tier_summary(&findings); + assert_eq!(a, b); + } + + // ── ModelCapabilityReport tests (#80) ── + + #[test] + fn capability_report_from_probe_roundtrips_json() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 1.2, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 350, + vocab_size: 32000, + }; + let report = ModelCapabilityReport::from_probe(&probe, ModelCapabilityTier::Basic); + let json = serde_json::to_string_pretty(&report).unwrap(); + assert!(json.contains("\"tier\": \"basic\"")); + assert!(json.contains("\"estimated_params_b\"")); + assert!(json.contains("\"execution_provider\"")); + assert!(json.contains("\"smoke_latency_ms\"")); + assert!(json.contains("\"vocab_size\"")); + // override should be absent when false + assert!(!json.contains("\"override\"")); + } + + #[test] + fn capability_report_override_flag_serialized() { + let probe = ModelCapabilityProbe { + estimated_param_billions: 1.2, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 350, + vocab_size: 32000, + }; + let mut report = ModelCapabilityReport::from_probe(&probe, ModelCapabilityTier::Strong); + report.r#override = true; + let json = 
serde_json::to_string_pretty(&report).unwrap(); + assert!(json.contains("\"override\": true")); + assert!(json.contains("\"tier\": \"strong\"")); + } } diff --git a/docs/schemas/examples/run-report.example.json b/docs/schemas/examples/run-report.example.json index f352498..fb82f1d 100644 --- a/docs/schemas/examples/run-report.example.json +++ b/docs/schemas/examples/run-report.example.json @@ -2,6 +2,14 @@ "contract_version": "1.0.0", "task": "Investigate unauthorized SSH keys", "case_id": "CASE-2026-IR-1001", + "max_severity": "medium", + "model_capability": { + "tier": "basic", + "estimated_params_b": 1.2, + "execution_provider": "CPUExecutionProvider", + "smoke_latency_ms": 350, + "vocab_size": 32000 + }, "live_fallback_decision": { "policy": "dry-run-on-error", "reason": "live inference failed and runtime fell back to dry-run", diff --git a/docs/schemas/run-report.schema.json b/docs/schemas/run-report.schema.json index a5009fb..f751e0a 100644 --- a/docs/schemas/run-report.schema.json +++ b/docs/schemas/run-report.schema.json @@ -21,6 +21,39 @@ "case_id": { "type": ["string", "null"] }, + "max_severity": { + "type": ["string", "null"], + "enum": ["info", "low", "medium", "high", "critical", null] + }, + "model_capability": { + "type": ["object", "null"], + "required": ["tier", "estimated_params_b", "execution_provider", "smoke_latency_ms", "vocab_size"], + "properties": { + "tier": { + "type": "string", + "enum": ["basic", "moderate", "strong"] + }, + "estimated_params_b": { + "type": "number", + "minimum": 0.0 + }, + "execution_provider": { + "type": "string" + }, + "smoke_latency_ms": { + "type": "integer", + "minimum": 0 + }, + "vocab_size": { + "type": "integer", + "minimum": 0 + }, + "override": { + "type": "boolean" + } + }, + "additionalProperties": false + }, "live_fallback_decision": { "type": ["object", "null"], "required": ["policy", "reason", "reason_code", "live_error", "fallback_mode"], diff --git a/inference_bridge/src/lib.rs 
b/inference_bridge/src/lib.rs index b6e8097..897ba23 100644 --- a/inference_bridge/src/lib.rs +++ b/inference_bridge/src/lib.rs @@ -7,6 +7,120 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tracing::debug; +/// Raw probe signals extracted from a model without running full inference. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelCapabilityProbe { + /// Estimated parameter count in billions, derived from model file size. + pub estimated_param_billions: f32, + /// Execution provider assigned by the runtime (e.g. "CPUExecutionProvider"). + pub execution_provider: String, + /// Single-token forward-pass latency in milliseconds (estimated heuristically when no live session is available). + pub smoke_latency_ms: u64, + /// Vocabulary size extracted from the tokenizer configuration (0 if unknown). + pub vocab_size: usize, +} + +impl Default for ModelCapabilityProbe { + fn default() -> Self { + Self { + estimated_param_billions: 0.0, + execution_provider: "CPUExecutionProvider".to_string(), + smoke_latency_ms: 999, + vocab_size: 0, + } + } +} + +/// Probe a model's capability signals without running full inference. +/// +/// On non-onnx builds, returns a sensible default (Basic-tier signals). +/// On onnx builds, extracts file size, EP, smoke latency, and vocab size. +pub fn probe_model_capability(config: &ModelConfig) -> ModelCapabilityProbe { + let estimated_param_billions = estimate_params_from_file_size(&config.model_path); + let execution_provider = detect_execution_provider(config); + let smoke_latency_ms = measure_smoke_latency(config); + let vocab_size = detect_vocab_size(config); + + ModelCapabilityProbe { + estimated_param_billions, + execution_provider, + smoke_latency_ms, + vocab_size, + } +} + +/// Estimate parameter count (in billions) from model file size. +/// Assumes ~2 bytes per parameter (float16/bfloat16 quantised models). 
+fn estimate_params_from_file_size(model_path: &PathBuf) -> f32 { + match std::fs::metadata(model_path) { + Ok(meta) => { + let bytes = meta.len() as f64; + // ~2 bytes per param for fp16/bf16; adjust for overhead (~10%). + let estimated_params = bytes / 2.2; + (estimated_params / 1_000_000_000.0) as f32 + } + Err(_) => 0.0, + } +} + +/// Detect which execution provider would be used for this config. +fn detect_execution_provider(config: &ModelConfig) -> String { + if config.vitis_config.is_some() { + "VitisAIExecutionProvider".to_string() + } else if cfg!(feature = "onnx") { + // Without Vitis config, ONNX Runtime defaults to CPU. + "CPUExecutionProvider".to_string() + } else { + "CPUExecutionProvider".to_string() + } +} + +/// Measure smoke latency. In dry-run or non-onnx builds, returns a default. +fn measure_smoke_latency(config: &ModelConfig) -> u64 { + if config.dry_run { + return 1; + } + // Without live ONNX session, estimate from file size heuristic: + // ~50ms per billion params on CPU as baseline estimate. + let params_b = estimate_params_from_file_size(&config.model_path); + if params_b > 0.0 { + (params_b * 50.0) as u64 + } else { + 999 + } +} + +/// Detect vocabulary size from model config. Returns 0 if unknown. +fn detect_vocab_size(config: &ModelConfig) -> usize { + // Try reading tokenizer.json to extract vocab size. + if let Some(tokenizer_path) = &config.tokenizer_path { + if let Ok(data) = std::fs::read_to_string(tokenizer_path) { + if let Ok(json) = serde_json::from_str::(&data) { + // HuggingFace tokenizer format: model.vocab has the vocab entries. + if let Some(vocab) = json + .get("model") + .and_then(|m| m.get("vocab")) + .and_then(|v| v.as_object()) + { + return vocab.len(); + } + // Alternative: added_tokens array length + base vocab. 
+ if let Some(added) = json.get("added_tokens").and_then(|a| a.as_array()) { + if let Some(base) = json + .get("model") + .and_then(|m| m.get("merges")) + .and_then(|m| m.as_array()) + { + // BPE vocab ≈ merges + 256 byte tokens + added tokens + return base.len() + 256 + added.len(); + } + } + } + } + } + 0 +} + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct VitisEpConfig { pub config_file: Option,