diff --git a/.github/workflows/wraithrun-scan.example.yml b/.github/workflows/wraithrun-scan.example.yml new file mode 100644 index 0000000..03974b7 --- /dev/null +++ b/.github/workflows/wraithrun-scan.example.yml @@ -0,0 +1,48 @@ +name: WraithRun Security Scan + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + # Nightly host triage at 02:00 UTC + - cron: '0 2 * * *' + workflow_dispatch: + +permissions: + contents: read + +jobs: + scan: + name: Security Investigation + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Run WraithRun scan + id: scan + uses: Shreyas582/wraithrun-action@v1 + with: + task: 'Triage this host for persistence mechanisms and suspicious accounts' + format: json + max-steps: 10 + fail-on-severity: high + + - name: Upload report artifact + if: always() + uses: actions/upload-artifact@v4 + with: + name: wraithrun-report + path: ${{ steps.scan.outputs.report-path }} + + - name: Summary + if: always() + run: | + echo "## WraithRun Scan Results" >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + echo "- **Findings:** ${{ steps.scan.outputs.finding-count }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Max Severity:** ${{ steps.scan.outputs.max-severity }}" >> "$GITHUB_STEP_SUMMARY" + echo "- **Exit Code:** ${{ steps.scan.outputs.exit-code }}" >> "$GITHUB_STEP_SUMMARY" diff --git a/CHANGELOG.md b/CHANGELOG.md index 96d1a57..ac74955 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,21 @@ The format is inspired by Keep a Changelog and this project follows Semantic Ver - (none yet) -## 1.2.0 - 2026-04-05 +## 1.3.0 - 2026-04-12 + +### Added + +- **CI/CD pipeline integration** (#103): first-party GitHub composite Action (`Shreyas582/wraithrun-action@v1`) with version resolution, binary caching, cross-platform install, scan execution, and JSON finding extraction. 
Also ships GitLab CI template, generic shell script for Jenkins/CircleCI, and an example GitHub Actions workflow. +- **CI integration guide** (`docs/ci-integration.md`): step-by-step docs for GitHub Actions, GitLab CI, and generic shell usage, covering exit code policy, output formats, scheduled scanning, and interpreting results. +- **`ExecutionProviderBackend` trait** (#47): hardware-agnostic backend abstraction in `inference_bridge::backend` with `name()`, `is_available()`, `priority()`, `build_session()`, and `diagnose()` methods. Includes `DiagnosticEntry` type for doctor integration and `InferenceSession` trait for provider-created sessions. +- **`ProviderRegistry`** (#48): runtime registry with `discover()`, `best_available()`, `get()`, `list()`, and `build_session_with_fallback()`. Auto-selects highest-priority available backend with cascading fallback on session init failure. +- **Built-in CPU backend**: always-available CPU execution provider (priority 0) with dry-run support and ONNX Runtime CPU session bridging. +- **Built-in Vitis backend** (cfg-gated): AMD Vitis AI NPU provider (priority 300, `vitis` feature) with environment-based availability detection and diagnostic checks. +- 12 new unit tests for backend trait, registry, and session functionality (245 total). + +### Changed + +- `inference_bridge` crate now exports `pub mod backend` alongside `pub mod onnx_vitis`. 
### Added diff --git a/Cargo.lock b/Cargo.lock index 200a159..640b5b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,7 +83,7 @@ checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "api_server" -version = "1.2.0" +version = "1.3.0" dependencies = [ "anyhow", "axum", @@ -308,7 +308,7 @@ dependencies = [ [[package]] name = "core_engine" -version = "1.2.0" +version = "1.3.0" dependencies = [ "anyhow", "async-trait", @@ -372,7 +372,7 @@ dependencies = [ [[package]] name = "cyber_tools" -version = "1.2.0" +version = "1.3.0" dependencies = [ "async-trait", "serde", @@ -791,7 +791,7 @@ dependencies = [ [[package]] name = "inference_bridge" -version = "1.2.0" +version = "1.3.0" dependencies = [ "anyhow", "async-trait", @@ -2117,7 +2117,7 @@ dependencies = [ [[package]] name = "wraithrun" -version = "1.2.0" +version = "1.3.0" dependencies = [ "anyhow", "api_server", diff --git a/Cargo.toml b/Cargo.toml index 02ff4a8..573de40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ resolver = "2" [workspace.package] edition = "2021" -version = "1.2.0" +version = "1.3.0" license = "MIT" [workspace.dependencies] diff --git a/action.yml b/action.yml new file mode 100644 index 0000000..133db7f --- /dev/null +++ b/action.yml @@ -0,0 +1,174 @@ +name: 'WraithRun Security Scan' +description: 'Run a WraithRun automated security investigation in your CI pipeline' +author: 'Shreyas582' + +branding: + icon: 'shield' + color: 'purple' + +inputs: + version: + description: 'WraithRun version to install (e.g. 
"1.2.0" or "latest")' + required: false + default: 'latest' + task: + description: 'Investigation task description' + required: true + profile: + description: 'Named configuration profile' + required: false + default: '' + max-steps: + description: 'Maximum agent investigation steps' + required: false + default: '10' + format: + description: 'Output format: json, summary, markdown, narrative' + required: false + default: 'json' + fail-on-severity: + description: 'Fail the step if any finding meets or exceeds this severity (none, info, low, medium, high, critical)' + required: false + default: 'none' + extra-args: + description: 'Additional CLI arguments passed to wraithrun' + required: false + default: '' + +outputs: + report-path: + description: 'Path to the generated report file' + value: ${{ steps.run.outputs.report_path }} + finding-count: + description: 'Total number of findings' + value: ${{ steps.run.outputs.finding_count }} + max-severity: + description: 'Highest severity finding (or "none")' + value: ${{ steps.run.outputs.max_severity }} + exit-code: + description: 'Exit code from the WraithRun scan' + value: ${{ steps.run.outputs.exit_code }} + +runs: + using: 'composite' + steps: + - name: Determine version + id: version + shell: bash + run: | + if [ "${{ inputs.version }}" = "latest" ]; then + VERSION=$(curl -sS https://api.github.com/repos/Shreyas582/WraithRun/releases/latest | grep '"tag_name"' | head -1 | sed 's/.*"v\(.*\)".*/\1/') + echo "resolved=${VERSION}" >> "$GITHUB_OUTPUT" + else + echo "resolved=${{ inputs.version }}" >> "$GITHUB_OUTPUT" + fi + + - name: Cache WraithRun binary + id: cache + uses: actions/cache@v4 + with: + path: ~/.wraithrun-bin + key: wraithrun-${{ runner.os }}-${{ steps.version.outputs.resolved }} + + - name: Install WraithRun + if: steps.cache.outputs.cache-hit != 'true' + shell: bash + run: | + set -euo pipefail + VERSION="${{ steps.version.outputs.resolved }}" + mkdir -p ~/.wraithrun-bin + + case "${{ runner.os }}" in + 
Linux)
+            ASSET="wraithrun-${VERSION}-x86_64-unknown-linux-gnu.tar.gz"
+            # -f makes curl fail on HTTP errors instead of saving the error page as the archive.
+            curl -fsSL "https://github.com/Shreyas582/WraithRun/releases/download/v${VERSION}/${ASSET}" -o /tmp/wraithrun.tar.gz
+            tar -xzf /tmp/wraithrun.tar.gz -C ~/.wraithrun-bin
+            ;;
+          macOS)
+            ASSET="wraithrun-${VERSION}-x86_64-apple-darwin.tar.gz"
+            curl -fsSL "https://github.com/Shreyas582/WraithRun/releases/download/v${VERSION}/${ASSET}" -o /tmp/wraithrun.tar.gz
+            tar -xzf /tmp/wraithrun.tar.gz -C ~/.wraithrun-bin
+            ;;
+          Windows)
+            ASSET="wraithrun-${VERSION}-x86_64-pc-windows-msvc.zip"
+            curl -fsSL "https://github.com/Shreyas582/WraithRun/releases/download/v${VERSION}/${ASSET}" -o "$TEMP/wraithrun.zip"
+            unzip -o "$TEMP/wraithrun.zip" -d ~/.wraithrun-bin
+            ;;
+          *)
+            # Fail loudly rather than continuing without a binary.
+            echo "::error::Unsupported runner OS: ${RUNNER_OS}"
+            exit 1
+            ;;
+        esac
+
+    - name: Add to PATH
+      shell: bash
+      run: echo "$HOME/.wraithrun-bin" >> "$GITHUB_PATH"
+
+    - name: Run WraithRun scan
+      id: run
+      shell: bash
+      # Inputs reach the script through environment variables rather than
+      # `${{ }}` interpolation, so quotes, backticks, or $(...) inside an input
+      # are treated as data and cannot inject shell commands.
+      env:
+        WR_TASK: ${{ inputs.task }}
+        WR_FORMAT: ${{ inputs.format }}
+        WR_MAX_STEPS: ${{ inputs.max-steps }}
+        WR_PROFILE: ${{ inputs.profile }}
+        WR_FAIL_ON_SEVERITY: ${{ inputs.fail-on-severity }}
+        WR_EXTRA_ARGS: ${{ inputs.extra-args }}
+      run: |
+        set -uo pipefail
+        REPORT_PATH="${RUNNER_TEMP}/wraithrun-report.json"
+        LOG_PATH="${RUNNER_TEMP}/wraithrun-stderr.log"
+
+        # Build CLI arguments
+        ARGS=(--task "$WR_TASK" --format "$WR_FORMAT" --max-steps "$WR_MAX_STEPS")
+
+        if [ -n "$WR_PROFILE" ]; then
+          ARGS+=(--profile "$WR_PROFILE")
+        fi
+
+        # Exit policy
+        if [ "$WR_FAIL_ON_SEVERITY" != "none" ]; then
+          ARGS+=(--exit-policy severity-threshold --exit-threshold "$WR_FAIL_ON_SEVERITY")
+        fi
+
+        # Extra user args: split on whitespace only, with globbing disabled so
+        # a literal `*` in an argument is not expanded against the workspace.
+        if [ -n "$WR_EXTRA_ARGS" ]; then
+          set -f
+          # shellcheck disable=SC2206
+          ARGS+=($WR_EXTRA_ARGS)
+          set +f
+        fi
+
+        # Run and capture exit code. stderr goes to a separate log so that
+        # diagnostics cannot corrupt the report written to stdout; the log is
+        # replayed afterwards so it still shows up in the job output.
+        EXIT_CODE=0
+        wraithrun "${ARGS[@]}" > "$REPORT_PATH" 2> "$LOG_PATH" || EXIT_CODE=$?
+        cat "$LOG_PATH" >&2
+
+        # Extract finding count and max severity from JSON report.
+        # The report path is passed as argv[1] instead of being spliced into
+        # the Python source, so backslashes or quotes in the path are harmless.
+        FINDING_COUNT=0
+        MAX_SEVERITY="none"
+        if [ -f "$REPORT_PATH" ] && command -v python3 &>/dev/null; then
+          FINDING_COUNT=$(python3 -c "
+        import json, sys
+        try:
+            with open(sys.argv[1], encoding='utf-8') as fh:
+                data = json.load(fh)
+            findings = data.get('findings', []) + data.get('supplementary_findings', [])
+            print(len(findings))
+        except Exception:
+            print(0)
+        " "$REPORT_PATH" 2>/dev/null || echo 0)
+
+          MAX_SEVERITY=$(python3 -c "
+        import json, sys
+        try:
+            with open(sys.argv[1], encoding='utf-8') as fh:
+                data = json.load(fh)
+            findings = data.get('findings', []) + data.get('supplementary_findings', [])
+            order = {'critical':5,'high':4,'medium':3,'low':2,'info':1}
+            best = 0
+            best_name = 'none'
+            for f in findings:
+                sev = f.get('severity','info').lower()
+                if order.get(sev,0) > best:
+                    best = order[sev]
+                    best_name = sev
+            print(best_name)
+        except Exception:
+            print('none')
+        " "$REPORT_PATH" 2>/dev/null || echo "none")
+        fi
+
+        echo "report_path=${REPORT_PATH}" >> "$GITHUB_OUTPUT"
+        echo "finding_count=${FINDING_COUNT}" >> "$GITHUB_OUTPUT"
+        echo "max_severity=${MAX_SEVERITY}" >> "$GITHUB_OUTPUT"
+        echo "exit_code=${EXIT_CODE}" >> "$GITHUB_OUTPUT"
+
+        # Fail the step if exit code is non-zero and severity policy is active
+        if [ "$EXIT_CODE" -ne 0 ] && [ "$WR_FAIL_ON_SEVERITY" != "none" ]; then
+          echo "::error::WraithRun found findings at or above '${WR_FAIL_ON_SEVERITY}' severity (exit code ${EXIT_CODE})"
+          exit "$EXIT_CODE"
+        fi
diff --git a/ci-templates/gitlab-ci.yml b/ci-templates/gitlab-ci.yml
new file mode 100644
index 0000000..d09fb14
--- /dev/null
+++ b/ci-templates/gitlab-ci.yml
@@ -0,0 +1,42 @@
+# GitLab CI template for WraithRun security scanning.
+# Include this in your .gitlab-ci.yml:
+#   include:
+#     - remote: https://raw.githubusercontent.com/Shreyas582/WraithRun/main/ci-templates/gitlab-ci.yml
+
+wraithrun-scan:
+  stage: test
+  image: ubuntu:22.04
+  variables:
+    WRAITHRUN_VERSION: "latest"
+    WRAITHRUN_TASK: "Triage this host for persistence mechanisms and suspicious accounts"
+    WRAITHRUN_FORMAT: "json"
+    WRAITHRUN_MAX_STEPS: "10"
+    WRAITHRUN_FAIL_SEVERITY: "none"
+  before_script:
+    - apt-get update -qq && apt-get install -y -qq curl python3 > /dev/null
+    - |
+      if [ "$WRAITHRUN_VERSION" = "latest" ]; then
+        WRAITHRUN_VERSION=$(curl -sS https://api.github.com/repos/Shreyas582/WraithRun/releases/latest | grep '"tag_name"' | head -1 | sed 's/.*"v\(.*\)".*/\1/')
+      fi
+    - curl -sSL "https://github.com/Shreyas582/WraithRun/releases/download/v${WRAITHRUN_VERSION}/wraithrun-${WRAITHRUN_VERSION}-x86_64-unknown-linux-gnu.tar.gz" -o /tmp/wraithrun.tar.gz
+    - tar -xzf /tmp/wraithrun.tar.gz -C /usr/local/bin
+  script:
+    - |
+      # Build the argument list as positional parameters instead of a string
+      # passed to `eval`: variable values (which may come from pipeline
+      # triggers) stay single arguments and are never re-parsed as shell code.
+      set -- --task "${WRAITHRUN_TASK}" --format "${WRAITHRUN_FORMAT}" --max-steps "${WRAITHRUN_MAX_STEPS}"
+      if [ "$WRAITHRUN_FAIL_SEVERITY" != "none" ]; then
+        set -- "$@" --exit-policy severity-threshold --exit-threshold "${WRAITHRUN_FAIL_SEVERITY}"
+      fi
+      EXIT_CODE=0
+      wraithrun "$@" > wraithrun-report.json 2>&1 || EXIT_CODE=$?
+      echo "Exit code: ${EXIT_CODE}"
+      # Only gate the pipeline when a severity threshold was requested.
+      if [ "${EXIT_CODE}" -ne 0 ] && [ "$WRAITHRUN_FAIL_SEVERITY" != "none" ]; then
+        exit "${EXIT_CODE}"
+      fi
+  artifacts:
+    when: always
+    paths:
+      - wraithrun-report.json
+    expire_in: 30 days
+  rules:
+    - if: $CI_PIPELINE_SOURCE == "merge_request_event"
+    - if: $CI_COMMIT_BRANCH == $CI_DEFAULT_BRANCH
+    - if: $CI_PIPELINE_SOURCE == "schedule"
diff --git a/ci-templates/wraithrun-scan.sh b/ci-templates/wraithrun-scan.sh
new file mode 100644
index 0000000..56c7780
--- /dev/null
+++ b/ci-templates/wraithrun-scan.sh
@@ -0,0 +1,85 @@
+#!/usr/bin/env bash
+# WraithRun CI Scanner — generic shell script for Jenkins, CircleCI, etc.
+# +# Usage: +# WRAITHRUN_TASK="Investigate host" ./ci-templates/wraithrun-scan.sh +# +# Environment variables: +# WRAITHRUN_VERSION Version to install (default: latest) +# WRAITHRUN_TASK Investigation task (required) +# WRAITHRUN_FORMAT Output format: json|summary|markdown|narrative (default: json) +# WRAITHRUN_MAX_STEPS Max investigation steps (default: 10) +# WRAITHRUN_FAIL_SEVERITY Fail threshold: none|info|low|medium|high|critical (default: none) +# WRAITHRUN_EXTRA_ARGS Additional CLI arguments +# WRAITHRUN_REPORT_PATH Output report path (default: ./wraithrun-report.json) + +set -euo pipefail + +: "${WRAITHRUN_TASK:?WRAITHRUN_TASK is required}" +: "${WRAITHRUN_VERSION:=latest}" +: "${WRAITHRUN_FORMAT:=json}" +: "${WRAITHRUN_MAX_STEPS:=10}" +: "${WRAITHRUN_FAIL_SEVERITY:=none}" +: "${WRAITHRUN_EXTRA_ARGS:=}" +: "${WRAITHRUN_REPORT_PATH:=./wraithrun-report.json}" + +INSTALL_DIR="${HOME}/.wraithrun-bin" + +# --- Install WraithRun --------------------------------------------------- + +install_wraithrun() { + local version="$1" + if [ "${version}" = "latest" ]; then + version=$(curl -sS https://api.github.com/repos/Shreyas582/WraithRun/releases/latest \ + | grep '"tag_name"' | head -1 | sed 's/.*"v\(.*\)".*/\1/') + fi + echo "Installing WraithRun v${version}..." + + mkdir -p "${INSTALL_DIR}" + local os + os=$(uname -s) + case "${os}" in + Linux) asset="wraithrun-${version}-x86_64-unknown-linux-gnu.tar.gz" ;; + Darwin) asset="wraithrun-${version}-x86_64-apple-darwin.tar.gz" ;; + *) echo "Unsupported OS: ${os}"; exit 1 ;; + esac + + curl -sSL "https://github.com/Shreyas582/WraithRun/releases/download/v${version}/${asset}" \ + -o /tmp/wraithrun.tar.gz + tar -xzf /tmp/wraithrun.tar.gz -C "${INSTALL_DIR}" + rm -f /tmp/wraithrun.tar.gz + export PATH="${INSTALL_DIR}:${PATH}" + echo "Installed: $(wraithrun --version)" +} + +# --- Run Scan ------------------------------------------------------------- + +if ! 
command -v wraithrun &>/dev/null; then + install_wraithrun "${WRAITHRUN_VERSION}" +fi + +ARGS=(--task "${WRAITHRUN_TASK}" --format "${WRAITHRUN_FORMAT}" --max-steps "${WRAITHRUN_MAX_STEPS}") + +if [ "${WRAITHRUN_FAIL_SEVERITY}" != "none" ]; then + ARGS+=(--exit-policy severity-threshold --exit-threshold "${WRAITHRUN_FAIL_SEVERITY}") +fi + +# shellcheck disable=SC2206 +if [ -n "${WRAITHRUN_EXTRA_ARGS}" ]; then + ARGS+=(${WRAITHRUN_EXTRA_ARGS}) +fi + +EXIT_CODE=0 +wraithrun "${ARGS[@]}" > "${WRAITHRUN_REPORT_PATH}" 2>&1 || EXIT_CODE=$? + +echo "" +echo "=== WraithRun Scan Complete ===" +echo "Report: ${WRAITHRUN_REPORT_PATH}" +echo "Exit code: ${EXIT_CODE}" + +if [ "${EXIT_CODE}" -ne 0 ] && [ "${WRAITHRUN_FAIL_SEVERITY}" != "none" ]; then + echo "FAILED: Findings at or above '${WRAITHRUN_FAIL_SEVERITY}' severity detected." + exit "${EXIT_CODE}" +fi + +exit 0 diff --git a/docs/ci-integration.md b/docs/ci-integration.md new file mode 100644 index 0000000..d4baf95 --- /dev/null +++ b/docs/ci-integration.md @@ -0,0 +1,132 @@ +# Integrating WraithRun in CI/CD + +Run automated security investigations on every push, pull request, or schedule. 
+ +## GitHub Actions + +Use the official WraithRun Action: + +```yaml +- name: Run WraithRun scan + uses: Shreyas582/wraithrun-action@v1 + with: + task: 'Triage this host for persistence mechanisms' + format: json + max-steps: 10 + fail-on-severity: high +``` + +### Action inputs + +| Input | Required | Default | Description | +|-------------------|----------|----------|------------------------------------------------------------| +| `version` | no | `latest` | WraithRun version to install | +| `task` | **yes** | — | Investigation task description | +| `profile` | no | — | Named configuration profile | +| `max-steps` | no | `10` | Maximum agent investigation steps | +| `format` | no | `json` | Output format: `json`, `summary`, `markdown`, `narrative` | +| `fail-on-severity`| no | `none` | Fail threshold: `none`, `info`, `low`, `medium`, `high`, `critical` | +| `extra-args` | no | — | Additional CLI arguments | + +### Action outputs + +| Output | Description | +|-----------------|-----------------------------------------| +| `report-path` | Path to the generated report file | +| `finding-count` | Total number of findings | +| `max-severity` | Highest finding severity (or `"none"`) | +| `exit-code` | WraithRun process exit code | + +### Full workflow example + +See [`.github/workflows/wraithrun-scan.example.yml`](https://github.com/Shreyas582/WraithRun/blob/main/.github/workflows/wraithrun-scan.example.yml) for a complete example with artifact upload, step summary, and scheduled nightly scans. 
+ +## GitLab CI + +Include the template or copy it into your `.gitlab-ci.yml`: + +```yaml +include: + - remote: https://raw.githubusercontent.com/Shreyas582/WraithRun/main/ci-templates/gitlab-ci.yml +``` + +Override variables to customize: + +```yaml +wraithrun-scan: + variables: + WRAITHRUN_TASK: "Check for unauthorized SSH keys" + WRAITHRUN_FAIL_SEVERITY: "medium" +``` + +## Jenkins / CircleCI / Generic + +Use the shell script in your pipeline: + +```bash +export WRAITHRUN_TASK="Investigate host for persistence" +export WRAITHRUN_FAIL_SEVERITY="high" +bash ci-templates/wraithrun-scan.sh +``` + +Or install directly: + +```bash +curl -sSL https://github.com/Shreyas582/WraithRun/releases/download/v1.2.0/wraithrun-1.2.0-x86_64-unknown-linux-gnu.tar.gz | tar -xz -C /usr/local/bin +wraithrun --task "Investigate host" --format json --exit-policy severity-threshold --exit-threshold high +``` + +## Exit code policy + +WraithRun supports exit code policies for CI gate decisions: + +| Flag | Values | Description | +|--------------------|--------------------------------------|------------------------------------------| +| `--exit-policy` | `none`, `severity-threshold` | When to use a non-zero exit code | +| `--exit-threshold` | `info`, `low`, `medium`, `high`, `critical` | Minimum severity to trigger failure | + +When `--exit-policy severity-threshold` is set and any finding meets or exceeds the threshold, WraithRun exits with code 1. This maps to a failed step in all CI systems. + +## Output formats + +| Format | Best for | +|-------------|---------------------------------| +| `json` | Machine parsing, dashboards | +| `summary` | Quick terminal overview | +| `markdown` | PR comments, documentation | +| `narrative` | Executive/stakeholder reporting | + +The `json` format follows the schema in [`docs/schemas/run-report.schema.json`](schemas/run-report.schema.json). See [Automation Contracts](automation-contracts.md) for full contract details. 
+ +## Scheduled scanning + +### Nightly host triage + +```yaml +on: + schedule: + - cron: '0 2 * * *' # 02:00 UTC daily +``` + +### Weekly persistence check + +```yaml +on: + schedule: + - cron: '0 6 * * 1' # 06:00 UTC every Monday +``` + +## Interpreting results + +1. **Check exit code** — non-zero means findings exceeded your threshold. +2. **Parse JSON report** — `findings` array contains all discovered issues. +3. **Review severity** — each finding has a `severity` field: `critical`, `high`, `medium`, `low`, `info`. +4. **Check confidence** — `confidence_label` indicates how certain the tool is. +5. **Follow evidence** — each finding includes an `evidence` field linking to tool observations. + +## Tips + +- Start with `fail-on-severity: critical` and lower the threshold as you remediate findings. +- Use `--profile` to run pre-configured investigation templates. +- Upload reports as artifacts for audit trail. +- Post `--format markdown` output as PR comments for visibility. diff --git a/docs/index.md b/docs/index.md index a7c19f5..a5ab47d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -11,6 +11,7 @@ Use this documentation to install, run, and operate WraithRun in your own enviro - [CLI Reference](cli-reference.md): all command-line options. - [Tool Reference](tool-reference.md): built-in tool behavior and expected outputs. - [Security and Sandbox](security-sandbox.md): policy controls and environment variables. +- [CI/CD Integration](ci-integration.md): run WraithRun in GitHub Actions, GitLab CI, Jenkins. - [Troubleshooting](troubleshooting.md): common errors and fixes. ## Investigation Playbooks diff --git a/docs/upgrades.md b/docs/upgrades.md index 3a5c7c5..d51037d 100644 --- a/docs/upgrades.md +++ b/docs/upgrades.md @@ -1,5 +1,65 @@ # Upgrade Notes +## v1.3.0 + +### Breaking/visible changes + +- `inference_bridge` now exports a new `backend` module. 
This is additive and fully backward-compatible — the existing `InferenceEngine` trait and `OnnxVitisEngine` are unchanged. +- The `backend::InferenceSession` trait introduces a synchronous `generate()` method that parallels the existing async `InferenceEngine::generate()`. Downstream callers can adopt it incrementally. + +### New infrastructure + +- **GitHub composite Action** (`action.yml`): use `Shreyas582/wraithrun-action@v1` in your CI workflows to run WraithRun scans with binary caching and cross-platform support. +- **CI templates**: GitLab CI (`ci-templates/gitlab-ci.yml`) and generic shell (`ci-templates/wraithrun-scan.sh`) for Jenkins, CircleCI, and other platforms. +- **CI integration guide** (`docs/ci-integration.md`): comprehensive setup docs for all supported CI systems. + +### New types in `inference_bridge::backend` + +| Type | Purpose | +|------|---------| +| `ExecutionProviderBackend` trait | Hardware-agnostic backend abstraction | +| `InferenceSession` trait | Provider-created inference session | +| `ProviderRegistry` | Runtime discovery and selection | +| `DiagnosticEntry` / `DiagnosticSeverity` | Backend self-check diagnostics | +| `ProviderInfo` | Backend metadata for listing | +| `BackendOptions` | Provider-specific config passthrough | +| `CpuBackend` | Built-in CPU provider (always available) | +| `VitisBackend` | Built-in Vitis NPU provider (cfg-gated) | + +### Migration examples + +To use the new backend registry: + +```rust +use inference_bridge::backend::{ProviderRegistry, BackendOptions}; +use inference_bridge::ModelConfig; + +let registry = ProviderRegistry::discover(); + +// List available backends +for info in registry.list() { + println!("{}: available={}, priority={}", info.name, info.available, info.priority); +} + +// Auto-select best backend and build a session +let config = ModelConfig { /* ... 
*/ }; +let (backend_name, session) = registry + .build_session_with_fallback(&config, &BackendOptions::new(), None) + .expect("no backend available"); +println!("Using backend: {backend_name}"); + +let output = session.generate("Analyze this host", 512)?; +``` + +To integrate WraithRun into GitHub Actions CI: + +```yaml +- uses: Shreyas582/wraithrun-action@v1 + with: + task: "Quick triage of ${{ github.sha }}" + fail-on-severity: high +``` + ## v1.2.0 ### Breaking/visible changes diff --git a/inference_bridge/src/backend.rs b/inference_bridge/src/backend.rs new file mode 100644 index 0000000..e82f434 --- /dev/null +++ b/inference_bridge/src/backend.rs @@ -0,0 +1,558 @@ +//! Execution provider backend abstraction. +//! +//! This module defines the [`ExecutionProviderBackend`] trait that decouples +//! the inference loop from any specific hardware execution provider. Backends +//! register themselves in a [`ProviderRegistry`] and are auto-selected based +//! on availability and priority. + +use std::collections::HashMap; +use std::fmt; + +use serde::{Deserialize, Serialize}; + +use crate::ModelConfig; + +// --------------------------------------------------------------------------- +// Diagnostic types +// --------------------------------------------------------------------------- + +/// Severity of a provider diagnostic entry. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DiagnosticSeverity { + Pass, + Warn, + Fail, +} + +impl fmt::Display for DiagnosticSeverity { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Pass => write!(f, "pass"), + Self::Warn => write!(f, "warn"), + Self::Fail => write!(f, "fail"), + } + } +} + +/// A single diagnostic entry produced by a backend's self-check. 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DiagnosticEntry {
+    pub severity: DiagnosticSeverity,
+    pub check: String,
+    pub message: String,
+}
+
+impl DiagnosticEntry {
+    /// Shared constructor behind [`pass`](Self::pass), [`warn`](Self::warn)
+    /// and [`fail`](Self::fail).
+    fn new(
+        severity: DiagnosticSeverity,
+        check: impl Into<String>,
+        message: impl Into<String>,
+    ) -> Self {
+        Self {
+            severity,
+            check: check.into(),
+            message: message.into(),
+        }
+    }
+
+    /// Builds an entry with [`DiagnosticSeverity::Pass`].
+    pub fn pass(check: impl Into<String>, message: impl Into<String>) -> Self {
+        Self::new(DiagnosticSeverity::Pass, check, message)
+    }
+
+    /// Builds an entry with [`DiagnosticSeverity::Warn`].
+    pub fn warn(check: impl Into<String>, message: impl Into<String>) -> Self {
+        Self::new(DiagnosticSeverity::Warn, check, message)
+    }
+
+    /// Builds an entry with [`DiagnosticSeverity::Fail`].
+    pub fn fail(check: impl Into<String>, message: impl Into<String>) -> Self {
+        Self::new(DiagnosticSeverity::Fail, check, message)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Backend trait
+// ---------------------------------------------------------------------------
+
+/// Information about a registered backend, returned by the registry.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ProviderInfo {
+    pub name: String,
+    pub priority: u32,
+    pub available: bool,
+}
+
+/// Abstraction over a hardware execution provider.
+///
+/// Backends report their availability at runtime (not just compile time) so
+/// the registry can probe what is actually present on the host. Each backend
+/// can produce an inference session and provider-specific diagnostics.
+pub trait ExecutionProviderBackend: Send + Sync {
+    /// Human-readable name (e.g. "CPU", "AMD Vitis NPU", "NVIDIA CUDA").
+    fn name(&self) -> &str;
+
+    /// Whether this backend is available on the current system.
+    ///
+    /// Backends should probe hardware/driver presence, not just check compile
+    /// flags. A compiled backend where the driver is missing returns `false`.
+    fn is_available(&self) -> bool;
+
+    /// Priority for auto-selection. Higher values are preferred.
+    ///
+    /// Suggested baseline: CPU = 0, DirectML = 100, CoreML = 100,
+    /// CUDA = 200, Vitis NPU = 300.
+    fn priority(&self) -> u32;
+
+    /// Provider-specific configuration keys that this backend reads from
+    /// [`BackendOptions`] (e.g. `"device_id"`, `"config_file"`).
+    ///
+    /// The default implementation reports no keys.
+    fn config_keys(&self) -> &[&str] {
+        &[]
+    }
+
+    /// Run provider-specific diagnostic checks.
+    fn diagnose(&self) -> Vec<DiagnosticEntry>;
+
+    /// Create a ready-to-use inference session for the given model config
+    /// and provider-specific options.
+    ///
+    /// This is the primary extension point. The provider translates
+    /// [`ModelConfig`] + [`BackendOptions`] into whatever internal session
+    /// type the runtime needs.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the provider cannot initialise a session; the
+    /// registry treats that as a signal to try the next backend.
+    fn build_session(
+        &self,
+        config: &ModelConfig,
+        options: &BackendOptions,
+    ) -> anyhow::Result<Box<dyn InferenceSession>>;
+}
+
+/// Provider-specific options passed through from CLI/config.
+///
+/// This is a string-keyed map so that new backends can read their own config
+/// keys without changing the core types.
+// NOTE(review): the value type was lost in this view of the patch; `String`
+// is assumed because options originate from CLI/config text. Confirm.
+pub type BackendOptions = HashMap<String, String>;
+
+// ---------------------------------------------------------------------------
+// Session trait
+// ---------------------------------------------------------------------------
+
+/// A provider-created inference session.
+///
+/// The inference loop calls `generate` regardless of which backend produced
+/// the session.
+pub trait InferenceSession: Send + Sync {
+    /// Generate text from a prompt using this session.
+    fn generate(&self, prompt: &str, max_new_tokens: usize) -> anyhow::Result<String>;
+}
+
+// ---------------------------------------------------------------------------
+// Provider registry
+// ---------------------------------------------------------------------------
+
+/// Runtime registry of execution provider backends.
+///
+/// Created once at startup, it discovers which compile-time-enabled backends
+/// are available and provides selection by priority or by name.
+pub struct ProviderRegistry {
+    backends: Vec<Box<dyn ExecutionProviderBackend>>,
+}
+
+impl ProviderRegistry {
+    /// Build a registry with all compile-time-enabled backends.
+    ///
+    /// Each backend probes its own availability. The registry stores all
+    /// backends (available or not) for diagnostic listing.
+    pub fn discover() -> Self {
+        let mut backends: Vec<Box<dyn ExecutionProviderBackend>> = Vec::new();
+
+        // CPU is always available.
+        backends.push(Box::new(CpuBackend));
+
+        // Vitis backend (only when compiled with the `vitis` feature).
+        #[cfg(feature = "vitis")]
+        backends.push(Box::new(VitisBackend));
+
+        Self { backends }
+    }
+
+    /// Returns the highest-priority available backend.
+    pub fn best_available(&self) -> Option<&dyn ExecutionProviderBackend> {
+        self.backends
+            .iter()
+            .filter(|b| b.is_available())
+            .max_by_key(|b| b.priority())
+            .map(|b| b.as_ref())
+    }
+
+    /// Returns a specific backend by name (case-insensitive).
+    pub fn get(&self, name: &str) -> Option<&dyn ExecutionProviderBackend> {
+        // `eq_ignore_ascii_case` compares in place; no lowercase copies needed.
+        self.backends
+            .iter()
+            .find(|b| b.name().eq_ignore_ascii_case(name))
+            .map(|b| b.as_ref())
+    }
+
+    /// Lists all registered backends with availability status.
+    pub fn list(&self) -> Vec<ProviderInfo> {
+        self.backends
+            .iter()
+            .map(|b| ProviderInfo {
+                name: b.name().to_string(),
+                priority: b.priority(),
+                available: b.is_available(),
+            })
+            .collect()
+    }
+
+    /// Lists the names of all available backends.
+    pub fn available_names(&self) -> Vec<String> {
+        self.backends
+            .iter()
+            .filter(|b| b.is_available())
+            .map(|b| b.name().to_string())
+            .collect()
+    }
+
+    /// Try to build a session using the specified backend, with automatic
+    /// fallback to the next-best available backend on failure.
+    ///
+    /// Returns the name of the backend that produced the session together
+    /// with the session itself.
+    ///
+    /// # Errors
+    ///
+    /// Fails when no available backend manages to build a session.
+    pub fn build_session_with_fallback(
+        &self,
+        config: &ModelConfig,
+        options: &BackendOptions,
+        preferred: Option<&str>,
+    ) -> anyhow::Result<(String, Box<dyn InferenceSession>)> {
+        // Name of a preferred backend whose session init already failed, so
+        // the fallback pass does not retry (and re-log) the same failure.
+        let mut already_failed: Option<&str> = None;
+
+        // If a preferred backend is specified, try it first.
+        if let Some(name) = preferred {
+            match self.get(name) {
+                Some(backend) if backend.is_available() => {
+                    match backend.build_session(config, options) {
+                        Ok(session) => return Ok((backend.name().to_string(), session)),
+                        Err(e) => {
+                            tracing::warn!(
+                                backend = backend.name(),
+                                error = %e,
+                                "preferred backend failed, trying fallback"
+                            );
+                            already_failed = Some(backend.name());
+                        }
+                    }
+                }
+                // Surface why the caller's choice is being ignored instead of
+                // silently auto-selecting something else.
+                Some(backend) => {
+                    tracing::warn!(
+                        backend = backend.name(),
+                        "preferred backend is not available, trying fallback"
+                    );
+                }
+                None => {
+                    tracing::warn!(
+                        backend = name,
+                        "preferred backend is not registered, trying fallback"
+                    );
+                }
+            }
+        }
+
+        // Auto-select: try backends by descending priority.
+        let mut candidates: Vec<&dyn ExecutionProviderBackend> = self
+            .backends
+            .iter()
+            .map(|b| b.as_ref())
+            .filter(|b| b.is_available() && Some(b.name()) != already_failed)
+            .collect();
+        // Stable sort keeps registration order among equal priorities.
+        candidates.sort_by_key(|b| std::cmp::Reverse(b.priority()));
+
+        for backend in candidates {
+            match backend.build_session(config, options) {
+                Ok(session) => return Ok((backend.name().to_string(), session)),
+                Err(e) => {
+                    tracing::warn!(
+                        backend = backend.name(),
+                        error = %e,
+                        "backend session init failed, trying next"
+                    );
+                }
+            }
+        }
+
+        anyhow::bail!("no execution provider backend could build a session")
+    }
+}
+
+impl fmt::Debug for ProviderRegistry {
+    /// Prints the registry as the list of [`ProviderInfo`] entries, since the
+    /// boxed backends themselves are not `Debug`.
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ProviderRegistry")
+            .field("backends", &self.list())
+            .finish()
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Built-in: CPU backend
+// ---------------------------------------------------------------------------
+
+/// CPU-only execution provider. Always available.
+pub struct CpuBackend;
+
+impl ExecutionProviderBackend for CpuBackend {
+    fn name(&self) -> &str {
+        "CPU"
+    }
+
+    fn is_available(&self) -> bool {
+        true
+    }
+
+    fn priority(&self) -> u32 {
+        0
+    }
+
+    fn diagnose(&self) -> Vec<DiagnosticEntry> {
+        vec![DiagnosticEntry::pass(
+            "cpu-backend",
+            "CPU execution provider is always available",
+        )]
+    }
+
+    /// Builds a CPU session: a dry-run placeholder when `config.dry_run` is
+    /// set or the `onnx` feature is compiled out, otherwise an ONNX-backed one.
+    fn build_session(
+        &self,
+        config: &ModelConfig,
+        _options: &BackendOptions,
+    ) -> anyhow::Result<Box<dyn InferenceSession>> {
+        if config.dry_run {
+            return Ok(Box::new(DryRunSession));
+        }
+
+        #[cfg(feature = "onnx")]
+        {
+            // Delegate to the existing onnx_vitis module for CPU session building.
+            // This is a proof-of-concept bridge — full extraction happens in #51.
+            Ok(Box::new(OnnxCpuSession {
+                config: config.clone(),
+            }))
+        }
+
+        #[cfg(not(feature = "onnx"))]
+        {
+            Ok(Box::new(DryRunSession))
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Built-in: Vitis backend (cfg-gated)
+// ---------------------------------------------------------------------------
+
+/// AMD Vitis AI NPU execution provider.
+#[cfg(feature = "vitis")]
+pub struct VitisBackend;
+
+#[cfg(feature = "vitis")]
+impl ExecutionProviderBackend for VitisBackend {
+    fn name(&self) -> &str {
+        "AMD Vitis NPU"
+    }
+
+    /// Reports the backend as available when either Vitis-related environment
+    /// variable is set.
+    ///
+    /// This impl only exists when the `vitis` feature is enabled, so the
+    /// compile-time flag must not take part in the answer: it would always be
+    /// true and make a host without the runtime look NPU-capable.
+    fn is_available(&self) -> bool {
+        // Check if Vitis runtime is discoverable.
+        // Full hardware probe will be implemented in #50.
+        std::env::var_os("RYZEN_AI_INSTALLER_PATH").is_some()
+            || std::env::var_os("XLNX_VART_FIRMWARE").is_some()
+    }
+
+    fn priority(&self) -> u32 {
+        300
+    }
+
+    fn config_keys(&self) -> &[&str] {
+        &["config_file", "cache_dir", "cache_key"]
+    }
+
+    fn diagnose(&self) -> Vec<DiagnosticEntry> {
+        let mut entries = vec![];
+        if std::env::var("RYZEN_AI_INSTALLER_PATH").is_ok() {
+            entries.push(DiagnosticEntry::pass(
+                "vitis-sdk",
+                "RYZEN_AI_INSTALLER_PATH is set",
+            ));
+        } else {
+            entries.push(DiagnosticEntry::warn(
+                "vitis-sdk",
+                "RYZEN_AI_INSTALLER_PATH not set; Vitis NPU may not be available",
+            ));
+        }
+        entries
+    }
+
+    fn build_session(
+        &self,
+        config: &ModelConfig,
+        _options: &BackendOptions,
+    ) -> anyhow::Result<Box<dyn InferenceSession>> {
+        if config.dry_run {
+            return Ok(Box::new(DryRunSession));
+        }
+        // Proof of concept — full Vitis session extraction happens in #50.
+        // For now, delegate to the existing monolithic path via OnnxCpuSession.
+        #[cfg(feature = "onnx")]
+        {
+            Ok(Box::new(OnnxCpuSession {
+                config: config.clone(),
+            }))
+        }
+
+        #[cfg(not(feature = "onnx"))]
+        {
+            anyhow::bail!("Vitis backend requires the 'onnx' feature")
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Session implementations
+// ---------------------------------------------------------------------------
+
+/// Dry-run session that returns a placeholder response.
+struct DryRunSession;
+
+impl InferenceSession for DryRunSession {
+    /// Ignores both arguments and returns a fixed placeholder string.
+    fn generate(&self, _prompt: &str, _max_new_tokens: usize) -> anyhow::Result<String> {
+        Ok("Dry-run session: no live inference performed.".to_string())
+    }
+}
+
+/// ONNX CPU session that delegates to the existing `onnx_vitis::run_prompt`.
+#[cfg(feature = "onnx")] +struct OnnxCpuSession { + config: ModelConfig, +} + +#[cfg(feature = "onnx")] +impl InferenceSession for OnnxCpuSession { + fn generate(&self, prompt: &str, _max_new_tokens: usize) -> anyhow::Result { + crate::onnx_vitis::run_prompt(&self.config, prompt) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cpu_backend_is_always_available() { + let cpu = CpuBackend; + assert!(cpu.is_available()); + assert_eq!(cpu.name(), "CPU"); + assert_eq!(cpu.priority(), 0); + } + + #[test] + fn cpu_diagnostics_pass() { + let cpu = CpuBackend; + let diags = cpu.diagnose(); + assert_eq!(diags.len(), 1); + assert_eq!(diags[0].severity, DiagnosticSeverity::Pass); + } + + #[test] + fn registry_discover_includes_cpu() { + let registry = ProviderRegistry::discover(); + let list = registry.list(); + assert!(!list.is_empty()); + assert!(list.iter().any(|p| p.name == "CPU" && p.available)); + } + + #[test] + fn registry_best_available_returns_something() { + let registry = ProviderRegistry::discover(); + let best = registry.best_available(); + assert!(best.is_some()); + } + + #[test] + fn registry_get_by_name_case_insensitive() { + let registry = ProviderRegistry::discover(); + assert!(registry.get("cpu").is_some()); + assert!(registry.get("CPU").is_some()); + assert!(registry.get("nonexistent").is_none()); + } + + #[test] + fn registry_available_names_includes_cpu() { + let registry = ProviderRegistry::discover(); + let names = registry.available_names(); + assert!(names.contains(&"CPU".to_string())); + } + + #[test] + fn dry_run_session_build() { + let config = ModelConfig { + model_path: std::path::PathBuf::from("test.onnx"), + tokenizer_path: None, + max_new_tokens: 1, + temperature: 0.0, + dry_run: true, + vitis_config: None, + }; + let cpu = CpuBackend; + let session = 
cpu.build_session(&config, &BackendOptions::new()); + assert!(session.is_ok()); + } + + #[test] + fn dry_run_session_generates() { + let session = DryRunSession; + let result = session.generate("test prompt", 10); + assert!(result.is_ok()); + assert!(result.unwrap().contains("Dry-run")); + } + + #[test] + fn diagnostic_entry_constructors() { + let pass = DiagnosticEntry::pass("check-a", "all good"); + assert_eq!(pass.severity, DiagnosticSeverity::Pass); + + let warn = DiagnosticEntry::warn("check-b", "maybe bad"); + assert_eq!(warn.severity, DiagnosticSeverity::Warn); + + let fail = DiagnosticEntry::fail("check-c", "broken"); + assert_eq!(fail.severity, DiagnosticSeverity::Fail); + } + + #[test] + fn provider_info_serializes() { + let info = ProviderInfo { + name: "CPU".to_string(), + priority: 0, + available: true, + }; + let json = serde_json::to_string(&info).unwrap(); + assert!(json.contains("\"CPU\"")); + } + + #[test] + fn build_session_with_fallback_works() { + let registry = ProviderRegistry::discover(); + let config = ModelConfig { + model_path: std::path::PathBuf::from("test.onnx"), + tokenizer_path: None, + max_new_tokens: 1, + temperature: 0.0, + dry_run: true, + vitis_config: None, + }; + let result = + registry.build_session_with_fallback(&config, &BackendOptions::new(), None); + assert!(result.is_ok()); + let (backend_name, _session) = result.unwrap(); + assert!(!backend_name.is_empty()); + } + + #[test] + fn build_session_with_fallback_preferred() { + let registry = ProviderRegistry::discover(); + let config = ModelConfig { + model_path: std::path::PathBuf::from("test.onnx"), + tokenizer_path: None, + max_new_tokens: 1, + temperature: 0.0, + dry_run: true, + vitis_config: None, + }; + let result = + registry.build_session_with_fallback(&config, &BackendOptions::new(), Some("CPU")); + assert!(result.is_ok()); + let (name, _) = result.unwrap(); + assert_eq!(name, "CPU"); + } +} diff --git a/inference_bridge/src/lib.rs b/inference_bridge/src/lib.rs 
index 897ba23..9c92930 100644 --- a/inference_bridge/src/lib.rs +++ b/inference_bridge/src/lib.rs @@ -1,3 +1,4 @@ +pub mod backend; pub mod onnx_vitis; use std::path::PathBuf; diff --git a/mkdocs.yml b/mkdocs.yml index cf6e454..69b1fe2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -39,6 +39,7 @@ nav: - Live-Mode Operations: live-mode-operations.md - Tool Reference: tool-reference.md - Plugin API: plugin-api.md + - CI/CD Integration: ci-integration.md - Security and Sandbox: security-sandbox.md - Troubleshooting: troubleshooting.md - Release and Operations: