diff --git a/CHANGELOG.md b/CHANGELOG.md index ac74955..ff2f509 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,9 @@ The format is inspired by Keep a Changelog and this project follows Semantic Ver ### Changed -- (none yet) +- **Provider-agnostic `ModelConfig`** (#49): replaced `vitis_config: Option<VitisEpConfig>` with generic `backend_override: Option<String>` and `backend_config: HashMap<String, String>`. `VitisEpConfig` is retained as a CLI-level helper with `into_backend_config()` / `from_backend_config()` conversion methods. +- **Vitis EP reads via `backend_config` map** (#50): `onnx_vitis` functions (`discover_ort_dylib_path`, `build_base_session_builder_with_provider`, `build_session_with_vitis_cascade`) now read config values from the generic `backend_config` map instead of the Vitis-specific struct. +- **CPU EP unblocked by config refactor** (#51): `CpuBackend` and all non-Vitis callers now use `backend_override: None, backend_config: Default::default()`, removing any coupling to Vitis types. ### Fixed diff --git a/api_server/src/routes.rs b/api_server/src/routes.rs index ec96fb3..5478a0d 100644 --- a/api_server/src/routes.rs +++ b/api_server/src/routes.rs @@ -353,7 +353,8 @@ async fn run_investigation(task: &str, max_steps: usize) -> anyhow::Result (Some("vitis".to_string()), cfg.into_backend_config()), + None => (None, Default::default()), + }; let compatibility = inspect_runtime_compatibility( &ModelConfig { model_path: runtime.model.clone(), @@ -4412,7 +4417,8 @@ fn run_model_pack_doctor_checks(runtime: &RuntimeConfig, report: &mut DoctorRepo max_new_tokens: 1, temperature: runtime.temperature, dry_run: false, - vitis_config: build_vitis_config(runtime), + backend_override: bo, + backend_config: bc, }, true, ); @@ -5734,13 +5740,18 @@ async fn run_agent_once(runtime: &RuntimeConfig, dry_run: bool) -> Result (Some("vitis".to_string()), cfg.into_backend_config()), + None => (None, Default::default()), + }; let model_config = ModelConfig { model_path: runtime.model.clone(),
tokenizer_path: runtime.tokenizer.clone(), max_new_tokens: runtime.max_new_tokens, temperature: runtime.temperature, dry_run, - vitis_config, + backend_override, + backend_config, }; // Determine capability tier: override > probe > default. diff --git a/docs/upgrades.md b/docs/upgrades.md index d51037d..f709a1b 100644 --- a/docs/upgrades.md +++ b/docs/upgrades.md @@ -1,5 +1,58 @@ # Upgrade Notes +## Unreleased + +### Breaking/visible changes + +- **`ModelConfig` struct changed** (#49): the `vitis_config: Option<VitisEpConfig>` field is replaced by two new fields: + - `backend_override: Option<String>` — optional backend name hint (e.g. `"vitis"`) + - `backend_config: HashMap<String, String>` — generic key-value config map + + Both fields default to empty via `#[serde(default)]`, so TOML/JSON deserialization is backward-compatible if you don't set them. + +- **`VitisEpConfig` is still available** as a helper. Use `into_backend_config()` to convert to the new map and `from_backend_config()` to reconstruct from one. + +### Migration + +Before: +```rust +let config = ModelConfig { + // ... + vitis_config: Some(VitisEpConfig { + config_file: Some("/path/to/vitis.json".into()), + cache_dir: None, + cache_key: None, + }), +}; +``` + +After: +```rust +use std::collections::HashMap; + +let config = ModelConfig { + // ... + backend_override: Some("vitis".to_string()), + backend_config: HashMap::from([ + ("config_file".to_string(), "/path/to/vitis.json".to_string()), + ]), +}; +``` + +Or using the helper: +```rust +let vitis = VitisEpConfig { + config_file: Some("/path/to/vitis.json".into()), + cache_dir: None, + cache_key: None, +}; +let config = ModelConfig { + // ...
+ backend_override: Some("vitis".to_string()), + backend_config: vitis.into_backend_config(), +}; +``` + ## v1.3.0 ### Breaking/visible changes diff --git a/inference_bridge/src/backend.rs b/inference_bridge/src/backend.rs index e82f434..c87c71f 100644 --- a/inference_bridge/src/backend.rs +++ b/inference_bridge/src/backend.rs @@ -482,7 +482,8 @@ mod tests { max_new_tokens: 1, temperature: 0.0, dry_run: true, - vitis_config: None, + backend_override: None, + backend_config: Default::default(), }; let cpu = CpuBackend; let session = cpu.build_session(&config, &BackendOptions::new()); @@ -529,7 +530,8 @@ mod tests { max_new_tokens: 1, temperature: 0.0, dry_run: true, - vitis_config: None, + backend_override: None, + backend_config: Default::default(), }; let result = registry.build_session_with_fallback(&config, &BackendOptions::new(), None); @@ -547,7 +549,8 @@ mod tests { max_new_tokens: 1, temperature: 0.0, dry_run: true, - vitis_config: None, + backend_override: None, + backend_config: Default::default(), }; let result = registry.build_session_with_fallback(&config, &BackendOptions::new(), Some("CPU")); diff --git a/inference_bridge/src/lib.rs b/inference_bridge/src/lib.rs index 9c92930..d3d2bad 100644 --- a/inference_bridge/src/lib.rs +++ b/inference_bridge/src/lib.rs @@ -66,7 +66,9 @@ fn estimate_params_from_file_size(model_path: &PathBuf) -> f32 { /// Detect which execution provider would be used for this config. fn detect_execution_provider(config: &ModelConfig) -> String { - if config.vitis_config.is_some() { + if config.backend_override.as_deref() == Some("vitis") + || config.backend_config.contains_key("config_file") + { "VitisAIExecutionProvider".to_string() } else if cfg!(feature = "onnx") { // Without Vitis config, ONNX Runtime defaults to CPU. @@ -129,6 +131,32 @@ pub struct VitisEpConfig { pub cache_key: Option<String>, } +impl VitisEpConfig { + /// Convert to a generic backend config map.
+ pub fn into_backend_config(self) -> std::collections::HashMap<String, String> { + let mut map = std::collections::HashMap::new(); + if let Some(v) = self.config_file { + map.insert("config_file".to_string(), v); + } + if let Some(v) = self.cache_dir { + map.insert("cache_dir".to_string(), v); + } + if let Some(v) = self.cache_key { + map.insert("cache_key".to_string(), v); + } + map + } + + /// Reconstruct from a generic backend config map. + pub fn from_backend_config(map: &std::collections::HashMap<String, String>) -> Self { + Self { + config_file: map.get("config_file").cloned(), + cache_dir: map.get("cache_dir").cloned(), + cache_key: map.get("cache_key").cloned(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelConfig { pub model_path: PathBuf, @@ -136,7 +164,12 @@ pub struct ModelConfig { pub max_new_tokens: usize, pub temperature: f32, pub dry_run: bool, - pub vitis_config: Option<VitisEpConfig>, + /// Explicit backend override (e.g., "cpu", "vitis", "cuda"). + #[serde(default)] + pub backend_override: Option<String>, + /// Provider-specific key-value configuration.
+ #[serde(default)] + pub backend_config: std::collections::HashMap<String, String>, } #[async_trait] @@ -356,7 +389,8 @@ mod tests { max_new_tokens: 16, temperature: 0.2, dry_run: true, - vitis_config: None, + backend_override: None, + backend_config: Default::default(), }) } diff --git a/inference_bridge/src/onnx_vitis.rs b/inference_bridge/src/onnx_vitis.rs index 4630a8d..5611b64 100644 --- a/inference_bridge/src/onnx_vitis.rs +++ b/inference_bridge/src/onnx_vitis.rs @@ -450,9 +450,8 @@ fn discover_ort_dylib_path(config: &ModelConfig) -> Option { let mut candidates = Vec::new(); if let Some(vitis_config_path) = config - .vitis_config - .as_ref() - .and_then(|cfg| cfg.config_file.as_deref()) + .backend_config + .get("config_file") { let vitis_config_path = PathBuf::from(vitis_config_path); if let Some(parent) = vitis_config_path.parent() { @@ -1203,16 +1202,14 @@ fn build_base_session_builder_with_provider( let mut vitis = ep::Vitis::default(); if use_vitis_provider { - if let Some(vitis_cfg) = &config.vitis_config { - if let Some(config_file) = &vitis_cfg.config_file { - vitis = vitis.with_config_file(config_file); - } - if let Some(cache_dir) = &vitis_cfg.cache_dir { - vitis = vitis.with_cache_dir(cache_dir); - } - if let Some(cache_key) = &vitis_cfg.cache_key { - vitis = vitis.with_cache_key(cache_key); - } + if let Some(config_file) = config.backend_config.get("config_file") { + vitis = vitis.with_config_file(config_file); + } + if let Some(cache_dir) = config.backend_config.get("cache_dir") { + vitis = vitis.with_cache_dir(cache_dir); + } + if let Some(cache_key) = config.backend_config.get("cache_key") { + vitis = vitis.with_cache_key(cache_key); } } @@ -1254,7 +1251,7 @@ fn build_session_with_vitis_cascade(config: &ModelConfig) -> Result { let force_cpu_provider = env_var_truthy("WRAITHRUN_FORCE_CPU_EP"); debug!( model = %config.model_path.display(), - has_vitis_config = config.vitis_config.is_some(), + has_vitis_config = config.backend_config.contains_key("config_file"),
force_cpu_provider, "building Vitis ONNX Runtime session" ); @@ -2792,7 +2789,8 @@ mod tests { max_new_tokens: 1, temperature: 0.0, dry_run: false, - vitis_config: None, + backend_override: None, + backend_config: Default::default(), }; let report = inspect_runtime_compatibility(&config, true);