From 81311d9d6c9efd1e7c1c99abeac59e4c4c6c9de4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:40:14 +0000 Subject: [PATCH 01/19] Initial plan From cf4f9ff242a9cca1efe776420c97ebba38203f74 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:01:15 +0000 Subject: [PATCH 02/19] feat: initial scaffold and core implementation of agent-kernel Co-authored-by: dgenio <12731907+dgenio@users.noreply.github.com> --- .github/workflows/ci.yml | 44 ++++ AGENTS.md | 38 +++ CHANGELOG.md | 20 ++ CONTRIBUTING.md | 36 +++ Makefile | 20 ++ README.md | 140 ++++++++++- docs/architecture.md | 70 ++++++ docs/capabilities.md | 49 ++++ docs/context_firewall.md | 64 +++++ docs/integrations.md | 70 ++++++ docs/security.md | 38 +++ examples/basic_cli.py | 144 ++++++++++++ examples/billing_demo.py | 157 +++++++++++++ examples/http_driver_demo.py | 137 +++++++++++ pyproject.toml | 55 +++++ src/agent_kernel/__init__.py | 135 +++++++++++ src/agent_kernel/drivers/__init__.py | 7 + src/agent_kernel/drivers/base.py | 42 ++++ src/agent_kernel/drivers/http.py | 120 ++++++++++ src/agent_kernel/drivers/memory.py | 171 ++++++++++++++ src/agent_kernel/enums.py | 32 +++ src/agent_kernel/errors.py | 59 +++++ src/agent_kernel/firewall/__init__.py | 8 + src/agent_kernel/firewall/budgets.py | 26 ++ src/agent_kernel/firewall/redaction.py | 101 ++++++++ src/agent_kernel/firewall/summarize.py | 115 +++++++++ src/agent_kernel/firewall/transform.py | 204 ++++++++++++++++ src/agent_kernel/handles.py | 188 +++++++++++++++ src/agent_kernel/kernel.py | 313 +++++++++++++++++++++++++ src/agent_kernel/models.py | 234 ++++++++++++++++++ src/agent_kernel/policy.py | 135 +++++++++++ src/agent_kernel/py.typed | 0 src/agent_kernel/registry.py | 124 ++++++++++ src/agent_kernel/router.py | 65 +++++ src/agent_kernel/tokens.py | 254 ++++++++++++++++++++ src/agent_kernel/trace.py | 46 
++++ tests/conftest.py | 172 ++++++++++++++ tests/test_drivers.py | 185 +++++++++++++++ tests/test_firewall.py | 157 +++++++++++++ tests/test_handles.py | 108 +++++++++ tests/test_kernel.py | 217 +++++++++++++++++ tests/test_models.py | 154 ++++++++++++ tests/test_policy.py | 194 +++++++++++++++ tests/test_registry.py | 106 +++++++++ tests/test_router.py | 39 +++ tests/test_tokens.py | 105 +++++++++ tests/test_trace.py | 58 +++++ 47 files changed, 4955 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/ci.yml create mode 100644 AGENTS.md create mode 100644 CHANGELOG.md create mode 100644 CONTRIBUTING.md create mode 100644 Makefile create mode 100644 docs/architecture.md create mode 100644 docs/capabilities.md create mode 100644 docs/context_firewall.md create mode 100644 docs/integrations.md create mode 100644 docs/security.md create mode 100644 examples/basic_cli.py create mode 100644 examples/billing_demo.py create mode 100644 examples/http_driver_demo.py create mode 100644 pyproject.toml create mode 100644 src/agent_kernel/__init__.py create mode 100644 src/agent_kernel/drivers/__init__.py create mode 100644 src/agent_kernel/drivers/base.py create mode 100644 src/agent_kernel/drivers/http.py create mode 100644 src/agent_kernel/drivers/memory.py create mode 100644 src/agent_kernel/enums.py create mode 100644 src/agent_kernel/errors.py create mode 100644 src/agent_kernel/firewall/__init__.py create mode 100644 src/agent_kernel/firewall/budgets.py create mode 100644 src/agent_kernel/firewall/redaction.py create mode 100644 src/agent_kernel/firewall/summarize.py create mode 100644 src/agent_kernel/firewall/transform.py create mode 100644 src/agent_kernel/handles.py create mode 100644 src/agent_kernel/kernel.py create mode 100644 src/agent_kernel/models.py create mode 100644 src/agent_kernel/policy.py create mode 100644 src/agent_kernel/py.typed create mode 100644 src/agent_kernel/registry.py create mode 100644 src/agent_kernel/router.py create mode 
100644 src/agent_kernel/tokens.py create mode 100644 src/agent_kernel/trace.py create mode 100644 tests/conftest.py create mode 100644 tests/test_drivers.py create mode 100644 tests/test_firewall.py create mode 100644 tests/test_handles.py create mode 100644 tests/test_kernel.py create mode 100644 tests/test_models.py create mode 100644 tests/test_policy.py create mode 100644 tests/test_registry.py create mode 100644 tests/test_router.py create mode 100644 tests/test_tokens.py create mode 100644 tests/test_trace.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..fc901e6 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,44 @@ +name: CI + +on: + push: + branches: ["main", "copilot/**"] + pull_request: + branches: ["main"] + +jobs: + test: + name: "Python ${{ matrix.python-version }}" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: pip install -e ".[dev]" + + - name: Lint (ruff check) + run: ruff check src/ tests/ examples/ + + - name: Format check (ruff format) + run: ruff format --check src/ tests/ examples/ + + - name: Type check (mypy) + run: mypy src/ + + - name: Test (pytest) + run: python -m pytest -q --cov=agent_kernel --cov-report=term-missing + + - name: Examples + run: | + python examples/basic_cli.py + python examples/billing_demo.py + python examples/http_driver_demo.py diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..589fff4 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,38 @@ +# AGENTS.md — AI Agent Instructions + +This file provides instructions for AI coding agents (Copilot, Cursor, etc.) working in this repository. 
+ +## Repo layout + +``` +src/agent_kernel/ — library source (one module per concern, ≤300 lines each) +tests/ — pytest test suite +examples/ — runnable demos (no internet required) +docs/ — architecture and security documentation +``` + +## Quality bar + +- `make ci` must pass before every commit. +- All public interfaces need type hints and docstrings. +- Use custom exceptions from `errors.py` — never bare `ValueError` or `KeyError`. +- Keep modules ≤ 300 lines. Split if needed. +- No randomness in matching, routing, or summarization (deterministic outputs). + +## Security rules + +- Never log or print secret key material. +- HMAC secrets come from `AGENT_KERNEL_SECRET` env var; fall back to a random dev secret with a logged warning. +- Tokens are tamper-evident (HMAC-SHA256) but not encrypted — document this. +- Confused-deputy prevention: tokens bind to `principal_id + capability_id + constraints`. + +## Adding a new capability driver + +1. Implement the `Driver` protocol in `src/agent_kernel/drivers/`. +2. Register it with `StaticRouter` or implement a custom `Router`. +3. Add integration tests in `tests/test_drivers.py`. + +## Adding a new policy rule + +1. Add the rule to `DefaultPolicyEngine.evaluate()` in `policy.py`. +2. Cover it with a test in `tests/test_policy.py`. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..56e7c36 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,20 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.0] - 2024-01-01 + +### Added +- Initial scaffold: `CapabilityRegistry`, `PolicyEngine`, `HMACTokenProvider`, `Kernel`. +- `InMemoryDriver` and `HTTPDriver` (httpx-based). +- Context `Firewall` with `Budgets`, redaction, and summarization. 
+- `HandleStore` with TTL, pagination, field selection, and basic filtering. +- `TraceStore` and `explain()` for full audit trail. +- Examples: `basic_cli.py`, `billing_demo.py`, `http_driver_demo.py`. +- Documentation: architecture, security model, integrations, capabilities, context firewall. +- CI pipeline for Python 3.10, 3.11, 3.12 with ruff + mypy + pytest. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..677e1dc --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,36 @@ +# Contributing to agent-kernel + +Thank you for your interest in contributing! + +## Development setup + +```bash +git clone https://github.com/dgenio/agent-kernel.git +cd agent-kernel +python -m venv .venv +source .venv/bin/activate +pip install -e ".[dev]" +``` + +## Running checks + +```bash +make fmt # auto-format with ruff +make lint # lint with ruff +make type # type-check with mypy +make test # run pytest with coverage +make ci # all of the above + examples +``` + +## Pull request guidelines + +1. Keep PRs focused — one logical change per PR. +2. Add or update tests for every behaviour change. +3. All checks in `make ci` must pass. +4. Follow the existing code style (ruff-enforced). +5. Write docstrings on all public interfaces. + +## Security + +Please report security vulnerabilities privately via GitHub Security Advisories. +Do **not** open a public issue for a security bug. 
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3a079f6 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: fmt lint type test example ci + +fmt: + ruff format src/ tests/ examples/ + +lint: + ruff check src/ tests/ examples/ + +type: + mypy src/ + +test: + python -m pytest -q --cov=agent_kernel + +example: + python examples/basic_cli.py + python examples/billing_demo.py + python examples/http_driver_demo.py + +ci: fmt lint type test example diff --git a/README.md b/README.md index f751c75..7fb1600 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,140 @@ # agent-kernel -Python library implementing a capability-based security kernel for AI agents operating in large tool ecosystems (MCP, A2A). Provides capability tokens, HMAC-signed authorization, policy engine, context firewall with budget enforcement, and pluggable drivers — so agents can safely use 1000+ tools without context blowup. + +[![CI](https://github.com/dgenio/agent-kernel/actions/workflows/ci.yml/badge.svg)](https://github.com/dgenio/agent-kernel/actions/workflows/ci.yml) +[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/) +[![License: Apache 2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) + +A capability-based security kernel for AI agents operating in large tool ecosystems (MCP, A2A, 1000+ tools). + +## 30-second pitch + +Modern AI agents face three hard problems when given access to hundreds or thousands of tools: + +1. **Context blowup** — raw tool output floods the LLM context window. +2. **Tool-space interference** — agents accidentally invoke the wrong tool or escalate privileges. +3. **No audit trail** — there's no record of what ran, when, and why. + +`agent-kernel` solves all three with a thin, composable layer that sits above your tool execution layer: + +- **Capability Tokens** — HMAC-signed, time-bounded, principal-scoped. No token → no execution. 
+- **Policy Engine** — READ/WRITE/DESTRUCTIVE safety classes + PII/PCI sensitivity handling. +- **Context Firewall** — raw driver output is *never* returned to the LLM; always a bounded `Frame`. +- **Audit Trail** — every invocation creates an `ActionTrace` retrievable via `kernel.explain()`. + +## Architecture + +```mermaid +graph LR + LLM["LLM / Agent"] -->|goal| K["Kernel"] + K -->|search| REG["Registry"] + K -->|evaluate| POL["Policy Engine"] + K -->|sign| TOK["HMAC Token"] + K -->|route| DRV["Driver (MCP/HTTP/Memory)"] + DRV -->|RawResult| FW["Context Firewall"] + FW -->|Frame| LLM + K -->|record| AUD["Audit Trace"] +``` + +## Quickstart + +```bash +pip install agent-kernel +``` + +```python +import asyncio, os +os.environ["AGENT_KERNEL_SECRET"] = "my-secret" + +from agent_kernel import ( + Capability, CapabilityRegistry, HMACTokenProvider, + InMemoryDriver, Kernel, Principal, SafetyClass, StaticRouter, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.models import CapabilityRequest + +# 1. Register a capability +registry = CapabilityRegistry() +registry.register(Capability( + capability_id="tasks.list", + name="List Tasks", + description="List all tasks", + safety_class=SafetyClass.READ, + tags=["tasks", "list"], +)) + +# 2. Wire up a driver +driver = InMemoryDriver() +driver.register_handler("tasks.list", lambda ctx: [{"id": 1, "title": "Buy milk"}]) + +# 3. Build the kernel +kernel = Kernel(registry=registry, router=StaticRouter(routes={"tasks.list": ["memory"]})) +kernel.register_driver(driver) + +async def main(): + principal = Principal(principal_id="alice", roles=["reader"]) + + # 4. Discover → grant → invoke → expand → explain + token = kernel.get_token( + CapabilityRequest(capability_id="tasks.list", goal="list tasks"), + principal, justification="", + ) + frame = await kernel.invoke(token, principal=principal, args={}) + print(frame.facts) # ['Total rows: 1', 'Top keys: id, title', ...] 
+ print(frame.handle) # Handle(handle_id='...', ...) + + expanded = kernel.expand(frame.handle, query={"limit": 1, "fields": ["title"]}) + print(expanded.table_preview) # [{'title': 'Buy milk'}] + + trace = kernel.explain(frame.action_id) + print(trace.driver_id) # 'memory' + +asyncio.run(main()) +``` + +## Where it fits + +``` +┌─────────────────────────────────────────────┐ +│ LLM / Agent loop │ +├─────────────────────────────────────────────┤ +│ agent-kernel ← you are here │ +│ (registry · policy · tokens · firewall) │ +├────────────────┬────────────────────────────┤ +│ contextweaver │ tool execution layer │ +│ (context │ (MCP · HTTP · A2A · │ +│ compilation) │ internal APIs) │ +└────────────────┴────────────────────────────┘ +``` + +`agent-kernel` sits **above** `contextweaver` (context compilation) and **above** raw tool execution. It provides the authorization, execution, and audit layer. + +## Security disclaimers + +> **v0.1 is not production-hardened for real authentication.** + +- HMAC tokens are tamper-evident (SHA-256) but **not encrypted**. Do not put sensitive data in token fields. +- Set `AGENT_KERNEL_SECRET` to a strong random value in production. If unset, a random dev secret is generated per-process with a warning. +- PII redaction is heuristic (regex). It is not a substitute for proper data governance. +- See [docs/security.md](docs/security.md) for the full threat model. + +## Documentation + +- [Architecture](docs/architecture.md) +- [Security model](docs/security.md) +- [Integrations (MCP, HTTPDriver)](docs/integrations.md) +- [Designing capabilities](docs/capabilities.md) +- [Context Firewall](docs/context_firewall.md) + +## Development + +```bash +git clone https://github.com/dgenio/agent-kernel +cd agent-kernel +pip install -e ".[dev]" +make ci # fmt + lint + type + test + examples +``` + +## License + +Apache-2.0 — see [LICENSE](LICENSE). 
+ diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..b2c7fe6 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,70 @@ +# Architecture + +## Overview + +`agent-kernel` is a capability-based security kernel that sits **above** raw tool execution (MCP, HTTP APIs, internal services) and **below** the LLM context window. + +```mermaid +graph TD + LLM["LLM / Agent"] -->|goal text| K["Kernel"] + K -->|search| REG["CapabilityRegistry"] + REG -->|CapabilityRequest| K + K -->|evaluate| POL["PolicyEngine"] + POL -->|PolicyDecision| K + K -->|issue| TOK["TokenProvider (HMAC)"] + TOK -->|CapabilityToken| K + K -->|route| ROU["Router"] + ROU -->|RoutePlan| K + K -->|execute| DRV["Driver (Memory / HTTP / MCP)"] + DRV -->|RawResult| K + K -->|transform| FW["Firewall"] + FW -->|Frame| K + K -->|store| HS["HandleStore"] + K -->|record| TS["TraceStore"] + K -->|Frame| LLM +``` + +## Components + +### Kernel +The central orchestrator. Wires all components together and exposes five methods: +- `request_capabilities(goal)` — discover relevant capabilities +- `grant_capability(request, principal, justification)` — policy check + token issuance +- `invoke(token, principal, args, response_mode)` — execute + firewall + trace +- `expand(handle, query)` — paginate/filter stored results +- `explain(action_id)` — retrieve audit trace + +### CapabilityRegistry +A flat dict of `Capability` objects indexed by `capability_id`. Provides keyword-based search (no LLM, no vector DB — purely token overlap scoring). + +### PolicyEngine +The `DefaultPolicyEngine` implements role-based rules: +1. **READ** — always allowed +2. **WRITE** — requires `justification ≥ 15 chars` + role `writer|admin` +3. **DESTRUCTIVE** — requires role `admin` +4. **PII/PCI** — requires `tenant` attribute; enforces `allowed_fields` unless `pii_reader` +5. **max_rows** — 50 (user), 500 (service) + +### TokenProvider (HMAC) +Issues HMAC-SHA256 signed tokens. 
Each token is bound to `principal_id + capability_id + constraints`. Verification checks: expiry → signature → principal → capability. + +### Router +`StaticRouter` maps `capability_id → [driver_id, ...]`. First driver that succeeds wins; others are tried as fallbacks. + +### Drivers +- **InMemoryDriver** — Python callables, used for tests and demos +- **HTTPDriver** — `httpx`-based async HTTP client +- (Future) **MCPDriver** — adapter for Model Context Protocol tool servers + +### Firewall +Transforms `RawResult → Frame`. Never exposes raw output to the LLM. +- Four response modes: `summary`, `table`, `handle_only`, `raw` +- Enforces `Budgets` (max_rows, max_fields, max_chars, max_depth) +- Redacts sensitive fields and inline PII patterns +- Deterministic summarisation (no LLM) + +### HandleStore +Stores full results by opaque handle ID with TTL. `expand()` supports pagination, field selection, and basic equality filtering. + +### TraceStore +Records every `ActionTrace`. `explain(action_id)` returns the full audit record. diff --git a/docs/capabilities.md b/docs/capabilities.md new file mode 100644 index 0000000..2863d88 --- /dev/null +++ b/docs/capabilities.md @@ -0,0 +1,49 @@ +# Designing Capabilities + +## Naming conventions + +- Use `domain.verb_noun` format: `billing.list_invoices`, `users.get_profile`. +- Be specific: prefer `billing.cancel_invoice` over `billing.update`. +- Avoid generic names like `billing.execute` or `api.call`. + +## Granularity + +Each capability should map to a single, auditable action with clear side-effects. 
+ +**Good:** +- `billing.list_invoices` (READ, no side-effects) +- `billing.send_reminder` (WRITE, sends an email) +- `billing.void_invoice` (DESTRUCTIVE, irreversible) + +**Avoid:** +- `billing.do_stuff` (too broad) +- `billing.list_or_update_invoices` (mixed safety classes) + +## Safety classes + +| Class | Examples | Policy | +|-------|---------|--------| +| READ | list, get, search, summarize | Always allowed | +| WRITE | create, update, send, approve | Justification + writer role | +| DESTRUCTIVE | delete, void, purge, terminate | Admin role only | + +## Sensitivity tags + +Use `SensitivityTag.PII` when results may contain: name, email, phone, SSN, address. +Use `SensitivityTag.PCI` when results may contain: card numbers, CVV, bank details. +Use `SensitivityTag.SECRETS` when results may contain: API keys, passwords, tokens. + +Always pair sensitivity tags with `allowed_fields` to restrict which fields are returned +to non-privileged callers. + +## Tags + +Add descriptive tags to improve keyword matching: + +```python +Capability( + capability_id="billing.list_invoices", + tags=["billing", "invoices", "list", "finance", "accounts receivable"], + ... +) +``` diff --git a/docs/context_firewall.md b/docs/context_firewall.md new file mode 100644 index 0000000..9d14d99 --- /dev/null +++ b/docs/context_firewall.md @@ -0,0 +1,64 @@ +# Context Firewall + +## Why it exists + +Large tool ecosystems produce large, verbose outputs. Passing raw tool output to an LLM +causes context blowup, leaks PII, and makes the agent unpredictable. The firewall +transforms every `RawResult` into a bounded `Frame` before the LLM sees it. 
+ +## Budgets + +```python +from agent_kernel.firewall.budgets import Budgets + +Budgets( + max_rows=50, # max rows in table_preview + max_fields=20, # max fields per row + max_chars=4000, # total characters across all facts + max_depth=3, # recursion depth for nested structures +) +``` + +## Response modes + +| Mode | What you get | When to use | +|------|-------------|-------------| +| `summary` | ≤20 fact strings + handle | Default; best for LLM context | +| `table` | ≤max_rows dicts + handle | When the LLM needs tabular data | +| `handle_only` | handle + warnings | Defer all data to an expand() call | +| `raw` | Full data (admin only) | Debugging; never for LLM context | + +## Handles + +A `Handle` is an opaque reference to the full dataset stored server-side. + +```python +# Stored automatically on every invoke() +handle = frame.handle + +# Expand with pagination +expanded = kernel.expand(handle, query={"offset": 10, "limit": 5}) + +# Field selection +expanded = kernel.expand(handle, query={"fields": ["id", "name"]}) + +# Basic filtering +expanded = kernel.expand(handle, query={"filter": {"status": "unpaid"}}) +``` + +## Redaction + +When a capability has `SensitivityTag.PII` or `SensitivityTag.PCI`: +- Fields in `Capability.allowed_fields` are kept (others removed) +- Sensitive field names (`email`, `phone`, `card_number`, `ssn`, etc.) are replaced with `[REDACTED]` +- Inline patterns in string values (email addresses, phone numbers, SSNs, card numbers) are redacted + +Principals with the `pii_reader` role bypass `allowed_fields` enforcement. 
+ +## Summarization + +Summaries are produced deterministically: +- **list of dicts** → row count + top keys + numeric stats + categorical distributions +- **dict** → key list + per-value type/value +- **string** → truncated to 500 chars +- **other** → repr() truncated to 200 chars diff --git a/docs/integrations.md b/docs/integrations.md new file mode 100644 index 0000000..4a3533b --- /dev/null +++ b/docs/integrations.md @@ -0,0 +1,70 @@ +# Integrations + +## MCP (Model Context Protocol) + +To integrate with an MCP server, implement a custom driver that wraps the MCP client: + +```python +from agent_kernel.drivers.base import Driver, ExecutionContext +from agent_kernel.models import RawResult + +class MCPDriver: + def __init__(self, mcp_client, driver_id: str = "mcp"): + self._client = mcp_client + self._driver_id = driver_id + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + operation = ctx.args.get("operation", ctx.capability_id) + result = await self._client.call_tool(operation, ctx.args) + return RawResult(capability_id=ctx.capability_id, data=result) +``` + +Then register it: + +```python +kernel.register_driver(MCPDriver(mcp_client)) +router.add_route("mcp.my_tool", ["mcp"]) +``` + +## HTTPDriver + +The built-in `HTTPDriver` supports GET, POST, PUT, DELETE: + +```python +from agent_kernel.drivers.http import HTTPDriver, HTTPEndpoint + +driver = HTTPDriver(driver_id="my_api") +driver.register_endpoint("users.list", HTTPEndpoint( + url="https://api.example.com/users", + method="GET", + headers={"Authorization": "Bearer ..."}, +)) +kernel.register_driver(driver) +``` + +## Custom drivers + +Any object implementing the `Driver` protocol can be registered: + +```python +class Driver(Protocol): + @property + def driver_id(self) -> str: ... + async def execute(self, ctx: ExecutionContext) -> RawResult: ... 
+``` + +## Capability mapping + +When mapping MCP tools to capabilities, prefer task-shaped names: + +| MCP tool | Capability ID | Safety class | +|----------|--------------|--------------| +| `list_files` | `fs.list_files` | READ | +| `read_file` | `fs.read_file` | READ | +| `write_file` | `fs.write_file` | WRITE | +| `delete_file` | `fs.delete_file` | DESTRUCTIVE | +| `execute_code` | `sandbox.run_code` | DESTRUCTIVE | diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000..06a03f4 --- /dev/null +++ b/docs/security.md @@ -0,0 +1,38 @@ +# Security Model + +## Threat model + +| Threat | Mitigation | +|--------|-----------| +| Tool-space interference (agent calls wrong tool) | Capability registry + policy gate before any execution | +| Confused deputy attack | Tokens are bound to `principal_id` — cannot be reused by another principal | +| Token forgery / tampering | HMAC-SHA256 signature; any bit flip → `TokenInvalid` | +| Token replay after expiry | Expiry checked on every `verify()` call | +| Context injection via raw tool output | Firewall always transforms `RawResult → Frame`; raw data never reaches LLM by default | +| PII / PCI leakage | Redaction + `allowed_fields` enforcement in the firewall | +| Privilege escalation via WRITE/DESTRUCTIVE | Policy engine enforces role requirements | +| Audit evasion | Every `invoke()` creates an immutable `ActionTrace` | + +## Token scopes + +A `CapabilityToken` binds: +- `capability_id` — which capability is authorized +- `principal_id` — who the token was issued to +- `constraints` — max_rows, allowed_fields, etc. (signed into the token) +- `expires_at` — validity window + +Any change to these fields invalidates the HMAC signature. + +## Confused deputy prevention + +Consider an agent that obtains a token for `billing.list_invoices` then passes it to a different agent. The second agent cannot use it because `verify()` checks that `token.principal_id == expected_principal_id`. 
+ +## Security disclaimers + +> **v0.1 is not production-hardened for real authentication.** + +- HMAC tokens are tamper-evident but **not encrypted**. Do not put sensitive data in token fields. +- The `AGENT_KERNEL_SECRET` must be kept secret. Rotate it if compromised. +- The default `InMemoryDriver` has no persistence — suitable for testing only. +- PII redaction is heuristic (regex-based). It is not a substitute for proper data governance. +- There is no rate limiting or quota enforcement in v0.1. diff --git a/examples/basic_cli.py b/examples/basic_cli.py new file mode 100644 index 0000000..507176f --- /dev/null +++ b/examples/basic_cli.py @@ -0,0 +1,144 @@ +"""basic_cli.py — Full flow: request → grant → invoke → expand → explain. + +Run with: python examples/basic_cli.py +""" + +from __future__ import annotations + +import asyncio +import os + +# Use a stable test secret so the example is reproducible. +os.environ.setdefault("AGENT_KERNEL_SECRET", "example-secret-do-not-use-in-prod") + +from agent_kernel import ( + Capability, + CapabilityRegistry, + HMACTokenProvider, + InMemoryDriver, + Kernel, + Principal, + SafetyClass, + SensitivityTag, + StaticRouter, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.models import CapabilityRequest, ImplementationRef + + +def build_registry() -> CapabilityRegistry: + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="tasks.list", + name="List Tasks", + description="List all tasks for the current user", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.NONE, + tags=["tasks", "list", "todo"], + impl=ImplementationRef(driver_id="memory", operation="list_tasks"), + ) + ) + registry.register( + Capability( + capability_id="tasks.create", + name="Create Task", + description="Create a new task", + safety_class=SafetyClass.WRITE, + tags=["tasks", "create", "write"], + impl=ImplementationRef(driver_id="memory", operation="create_task"), + ) + ) + return 
registry + + +def build_driver() -> InMemoryDriver: + driver = InMemoryDriver(driver_id="memory") + + tasks = [{"id": i, "title": f"Task {i}", "done": i % 3 == 0} for i in range(1, 21)] + + def list_tasks(ctx: ExecutionContext) -> list[dict[str, object]]: + return tasks + + def create_task(ctx: ExecutionContext) -> dict[str, object]: + task = {"id": len(tasks) + 1, "title": ctx.args.get("title", "Untitled"), "done": False} + tasks.append(task) + return task + + driver.register_handler("list_tasks", list_tasks) + driver.register_handler("create_task", create_task) + return driver + + +async def main() -> None: + registry = build_registry() + driver = build_driver() + + router = StaticRouter( + routes={ + "tasks.list": ["memory"], + "tasks.create": ["memory"], + } + ) + + kernel = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="example-secret-do-not-use-in-prod"), + router=router, + ) + kernel.register_driver(driver) + + reader = Principal( + principal_id="cli-user-001", + roles=["reader"], + attributes={}, + ) + + print("=== Step 1: Discover capabilities ===") + requests = kernel.request_capabilities("list my tasks") + print(f"Found {len(requests)} matching capabilities:") + for req in requests: + print(f" - {req.capability_id}") + + print("\n=== Step 2: Grant (get token) ===") + list_req = CapabilityRequest(capability_id="tasks.list", goal="list my tasks") + token = kernel.get_token(list_req, reader, justification="") + print(f" Token ID: {token.token_id}") + print(f" Expires: {token.expires_at.isoformat()}") + + print("\n=== Step 3: Invoke ===") + frame = await kernel.invoke( + token, + principal=reader, + args={"operation": "list_tasks"}, + response_mode="summary", + ) + print(f" Action ID: {frame.action_id}") + print(f" Mode: {frame.response_mode}") + print(" Facts:") + for fact in frame.facts: + print(f" • {fact}") + + print("\n=== Step 4: Expand handle ===") + if frame.handle: + expanded = kernel.expand( + frame.handle, + 
def build_registry() -> CapabilityRegistry:
    """Create the capability registry used by the billing demo.

    Two read-only capabilities are registered, both implemented by the
    ``billing`` driver: a PII-tagged invoice listing restricted to a fixed
    set of fields, and an aggregate spend summary.

    Returns:
        A registry containing ``billing.list_invoices`` and
        ``billing.summarize_spend``, registered in that order.
    """
    invoice_listing = Capability(
        capability_id="billing.list_invoices",
        name="List Invoices",
        description="List all invoices, optionally filtered by status",
        safety_class=SafetyClass.READ,
        sensitivity=SensitivityTag.PII,
        allowed_fields=["id", "customer_name", "amount", "currency", "status", "date"],
        tags=["billing", "invoices", "list"],
        impl=ImplementationRef(driver_id="billing", operation="list_invoices"),
    )
    spend_summary = Capability(
        capability_id="billing.summarize_spend",
        name="Summarize Spend",
        description="Summarize total spend per currency and status",
        safety_class=SafetyClass.READ,
        tags=["billing", "summary", "analytics"],
        impl=ImplementationRef(driver_id="billing", operation="summarize_spend"),
    )

    catalog = CapabilityRegistry()
    for capability in (invoice_listing, spend_summary):
        catalog.register(capability)
    return catalog


async def main() -> None:
    """Run the billing walkthrough: invoke, expand, summarize, explain.

    Output is printed to stdout; nothing is returned.
    """
    capabilities = build_registry()
    driver = make_billing_driver()

    routes = {
        "billing.list_invoices": ["billing"],
        "billing.summarize_spend": ["billing"],
    }

    # Deliberately small budgets so the firewall's enforcement is visible.
    tight_budgets = Budgets(max_rows=5, max_fields=10, max_chars=2000)

    kernel = Kernel(
        registry=capabilities,
        token_provider=HMACTokenProvider(secret="example-secret-do-not-use-in-prod"),
        router=StaticRouter(routes=routes),
        firewall=Firewall(budgets=tight_budgets),
    )
    kernel.register_driver(driver)

    # The demo principal is a reader carrying a tenant attribute.
    analyst = Principal(
        principal_id="analyst-001",
        roles=["reader"],
        attributes={"tenant": "acme"},
    )

    print("=== Billing Demo ===\n")

    # 1) List invoices in summary mode.
    print("--- list_invoices (summary) ---")
    list_request = CapabilityRequest(
        capability_id="billing.list_invoices", goal="list invoices"
    )
    list_token = kernel.get_token(list_request, analyst, justification="")
    frame = await kernel.invoke(
        list_token,
        principal=analyst,
        args={"operation": "list_invoices"},
        response_mode="summary",
    )
    print(f"Facts ({len(frame.facts)}):")
    for fact in frame.facts:
        print(f" • {fact}")
    if frame.warnings:
        print("Warnings:")
        for warning in frame.warnings[:3]:
            print(f" ⚠ {warning}")

    # 2) Expand the handle: a projected page of rows.
    print("\n--- expand: first 3 rows, id+amount+status ---")
    if frame.handle:
        page_query = {"offset": 0, "limit": 3, "fields": ["id", "amount", "status"]}
        expanded = kernel.expand(frame.handle, query=page_query)
        for row in expanded.table_preview:
            print(f" {row}")

    # 3) Expand the handle again, this time with a filter.
    print("\n--- expand: filter overdue ---")
    if frame.handle:
        overdue_query = {
            "filter": {"status": "overdue"},
            "limit": 3,
            "fields": ["id", "amount"],
        }
        overdue = kernel.expand(frame.handle, query=overdue_query)
        print(f" Overdue rows returned: {len(overdue.table_preview)}")
        for row in overdue.table_preview:
            print(f" {row}")

    # 4) Aggregate spend.
    print("\n--- summarize_spend ---")
    spend_request = CapabilityRequest(
        capability_id="billing.summarize_spend", goal="summarize"
    )
    spend_token = kernel.get_token(spend_request, analyst, justification="")
    spend_frame = await kernel.invoke(
        spend_token,
        principal=analyst,
        args={"operation": "summarize_spend"},
        response_mode="summary",
    )
    for fact in spend_frame.facts:
        print(f" • {fact}")

    # 5) Audit trail for the last action.
    print("\n--- explain ---")
    trace = kernel.explain(spend_frame.action_id)
    print(f" Action: {trace.action_id}")
    print(f" Capability:{trace.capability_id}")
    print(f" Driver: {trace.driver_id}")
    print(f" At: {trace.invoked_at.isoformat()}")

    print("\n✓ billing_demo.py complete.")


if __name__ == "__main__":
    asyncio.run(main())
# Static catalogue served by the demo HTTP server.
_PRODUCTS = [{"id": i, "name": f"Product {i}", "price": round(i * 9.99, 2)} for i in range(1, 11)]


class _Handler(BaseHTTPRequestHandler):
    """Minimal request handler exposing the product list as JSON."""

    def do_GET(self) -> None:  # noqa: N802
        """Serve ``/products`` (query string allowed); answer 404 otherwise."""
        # Compare the path component only, so "/products?limit=3" is served
        # but an unrelated path such as "/productsXYZ" is not.
        route = self.path.partition("?")[0]
        if route == "/products" or route.startswith("/products/"):
            body = json.dumps(_PRODUCTS).encode()
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.send_header("Content-Length", str(len(body)))
            self.end_headers()
            self.wfile.write(body)
        else:
            self.send_response(404)
            self.send_header("Content-Length", "0")
            self.end_headers()

    def log_message(self, format: str, *args: object) -> None:  # noqa: A002
        """Suppress the default per-request logging to stderr."""


def _start_server(port: int) -> HTTPServer:
    """Start the demo server on localhost in a daemon thread.

    Args:
        port: TCP port to bind. ``0`` lets the operating system pick a free
            port; read it back from ``server.server_address[1]``.

    Returns:
        The running server. The caller must call ``shutdown()`` and then
        ``server_close()`` to stop it and release the socket.
    """
    server = HTTPServer(("127.0.0.1", port), _Handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server


# ── Demo ────────────────────────────────────────────────────────────────────────


async def main() -> None:
    """Run the discover → invoke → expand → explain flow against the local server."""
    # Bind to an OS-assigned port so the demo cannot fail because a fixed
    # port is already in use.
    server = _start_server(0)
    port = server.server_address[1]

    try:
        registry = CapabilityRegistry()
        registry.register(
            Capability(
                capability_id="catalog.list_products",
                name="List Products",
                description="List all products in the catalog",
                safety_class=SafetyClass.READ,
                tags=["catalog", "products", "list"],
            )
        )

        http_driver = HTTPDriver(driver_id="catalog_api")
        http_driver.register_endpoint(
            "catalog.list_products",
            HTTPEndpoint(url=f"http://127.0.0.1:{port}/products", method="GET"),
        )

        router = StaticRouter(routes={"catalog.list_products": ["catalog_api"]})
        token_provider = HMACTokenProvider(secret="example-secret-do-not-use-in-prod")

        kernel = Kernel(registry=registry, router=router, token_provider=token_provider)
        kernel.register_driver(http_driver)

        principal = Principal(principal_id="demo-user-001", roles=["reader"])

        print("=== HTTP Driver Demo ===\n")

        print("--- Discovering capabilities ---")
        requests = kernel.request_capabilities("list products in catalog")
        for req in requests:
            print(f" - {req.capability_id}")

        print("\n--- Invoking catalog.list_products ---")
        token = kernel.get_token(
            CapabilityRequest(capability_id="catalog.list_products", goal="list products"),
            principal,
            justification="",
        )
        frame = await kernel.invoke(
            token,
            principal=principal,
            args={"operation": "catalog.list_products"},
            response_mode="summary",
        )
        print(f" Mode: {frame.response_mode}")
        print(" Facts:")
        for fact in frame.facts:
            print(f" • {fact}")

        print("\n--- Expanding first 3 products ---")
        if frame.handle:
            expanded = kernel.expand(
                frame.handle,
                query={"limit": 3, "fields": ["id", "name", "price"]},
            )
            for row in expanded.table_preview:
                print(f" {row}")

        print("\n--- Explain ---")
        trace = kernel.explain(frame.action_id)
        print(f" Driver: {trace.driver_id}")
        print(f" At: {trace.invoked_at.isoformat()}")

        print("\n✓ http_driver_demo.py complete.")
    finally:
        # shutdown() only stops serve_forever(); server_close() is what
        # actually releases the listening socket.
        server.shutdown()
        server.server_close()


if __name__ == "__main__":
    asyncio.run(main())
tool ecosystems" +readme = "README.md" +license = { file = "LICENSE" } +requires-python = ">=3.10" +authors = [{ name = "agent-kernel contributors" }] +keywords = ["ai", "agents", "mcp", "security", "capabilities", "llm"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Security", + "Topic :: Software Development :: Libraries :: Python Modules", +] +dependencies = ["httpx>=0.27"] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-cov>=5.0", + "pytest-asyncio>=0.23", + "ruff>=0.4", + "mypy>=1.10", + "httpx>=0.27", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/agent_kernel"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +line-length = 99 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B", "SIM"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.10" +strict = true +files = ["src/"] diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py new file mode 100644 index 0000000..acc22e6 --- /dev/null +++ b/src/agent_kernel/__init__.py @@ -0,0 +1,135 @@ +"""agent-kernel: capability-based security kernel for AI agents. 
+ +Public API +---------- + +Core classes:: + + from agent_kernel import Kernel, CapabilityRegistry + from agent_kernel import Capability, Principal + from agent_kernel import SafetyClass, SensitivityTag + +Token management:: + + from agent_kernel import HMACTokenProvider, CapabilityToken + +Policy:: + + from agent_kernel import DefaultPolicyEngine + +Firewall:: + + from agent_kernel import Firewall, Budgets + +Handles & traces:: + + from agent_kernel import HandleStore, TraceStore + +Errors:: + + from agent_kernel import ( + AgentKernelError, + TokenExpired, TokenInvalid, TokenScopeError, + PolicyDenied, DriverError, FirewallError, + CapabilityNotFound, HandleNotFound, HandleExpired, + ) +""" + +from .drivers.base import Driver, ExecutionContext +from .drivers.http import HTTPDriver +from .drivers.memory import InMemoryDriver, make_billing_driver +from .enums import SafetyClass, SensitivityTag +from .errors import ( + AgentKernelError, + CapabilityNotFound, + DriverError, + FirewallError, + HandleExpired, + HandleNotFound, + PolicyDenied, + TokenExpired, + TokenInvalid, + TokenScopeError, +) +from .firewall.budgets import Budgets +from .firewall.transform import Firewall +from .handles import HandleStore +from .kernel import Kernel +from .models import ( + ActionTrace, + Capability, + CapabilityGrant, + CapabilityRequest, + Frame, + Handle, + ImplementationRef, + PolicyDecision, + Principal, + Provenance, + RawResult, + ResponseMode, + RoutePlan, +) +from .policy import DefaultPolicyEngine +from .registry import CapabilityRegistry +from .router import StaticRouter +from .tokens import CapabilityToken, HMACTokenProvider +from .trace import TraceStore + +__version__ = "0.1.0" + +__all__ = [ + # version + "__version__", + # kernel + "Kernel", + # registry + "CapabilityRegistry", + # models + "Capability", + "CapabilityGrant", + "CapabilityRequest", + "CapabilityToken", + "Frame", + "Handle", + "ImplementationRef", + "PolicyDecision", + "Principal", + "Provenance", + 
"RawResult", + "ResponseMode", + "RoutePlan", + "ActionTrace", + # enums + "SafetyClass", + "SensitivityTag", + # errors + "AgentKernelError", + "CapabilityNotFound", + "DriverError", + "FirewallError", + "HandleExpired", + "HandleNotFound", + "PolicyDenied", + "TokenExpired", + "TokenInvalid", + "TokenScopeError", + # policy + "DefaultPolicyEngine", + # tokens + "HMACTokenProvider", + # router + "StaticRouter", + # drivers + "Driver", + "ExecutionContext", + "InMemoryDriver", + "HTTPDriver", + "make_billing_driver", + # firewall + "Firewall", + "Budgets", + # stores + "HandleStore", + "TraceStore", +] diff --git a/src/agent_kernel/drivers/__init__.py b/src/agent_kernel/drivers/__init__.py new file mode 100644 index 0000000..6163c73 --- /dev/null +++ b/src/agent_kernel/drivers/__init__.py @@ -0,0 +1,7 @@ +"""Driver sub-package exports.""" + +from .base import Driver, ExecutionContext +from .http import HTTPDriver +from .memory import InMemoryDriver + +__all__ = ["Driver", "ExecutionContext", "HTTPDriver", "InMemoryDriver"] diff --git a/src/agent_kernel/drivers/base.py b/src/agent_kernel/drivers/base.py new file mode 100644 index 0000000..7b43235 --- /dev/null +++ b/src/agent_kernel/drivers/base.py @@ -0,0 +1,42 @@ +"""Base driver protocol and execution context.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol + +from ..models import RawResult + + +@dataclass(slots=True) +class ExecutionContext: + """Runtime context passed to a driver when executing a capability.""" + + capability_id: str + principal_id: str + args: dict[str, Any] = field(default_factory=dict) + constraints: dict[str, Any] = field(default_factory=dict) + action_id: str = "" + + +class Driver(Protocol): + """Interface for capability execution drivers.""" + + @property + def driver_id(self) -> str: + """Unique identifier for this driver instance.""" + ... 
# Methods whose arguments travel in the query string rather than a JSON body.
_QUERY_METHODS = frozenset({"GET", "DELETE"})


@dataclass
class HTTPEndpoint:
    """Describes an HTTP endpoint for a capability operation.

    Attributes:
        url: Absolute URL to call.
        method: HTTP method name; compared case-insensitively.
        headers: Extra headers; they override the driver's base headers.
        timeout: Request timeout in seconds, or ``None`` to use the owning
            driver's ``default_timeout``.
    """

    url: str
    method: str = "GET"
    headers: dict[str, str] = field(default_factory=dict)
    timeout: float | None = None


class HTTPDriver:
    """A driver that invokes capabilities via HTTP using :mod:`httpx`.

    Each operation must be registered with an :class:`HTTPEndpoint`.
    Requests are issued asynchronously through ``httpx.AsyncClient``; a new
    client is created for every call.
    """

    def __init__(
        self,
        driver_id: str = "http",
        *,
        base_headers: dict[str, str] | None = None,
        default_timeout: float = 30.0,
    ) -> None:
        """Create a driver.

        Args:
            driver_id: Identifier reported by :attr:`driver_id`.
            base_headers: Headers sent with every request.
            default_timeout: Timeout in seconds for endpoints that do not
                set their own.
        """
        self._driver_id = driver_id
        self._endpoints: dict[str, HTTPEndpoint] = {}
        self._base_headers = base_headers or {}
        self._default_timeout = default_timeout

    @property
    def driver_id(self) -> str:
        """Unique identifier for this driver."""
        return self._driver_id

    def register_endpoint(self, operation: str, endpoint: HTTPEndpoint) -> None:
        """Register an HTTP endpoint for an operation.

        Args:
            operation: The operation name to handle.
            endpoint: The :class:`HTTPEndpoint` configuration.
        """
        self._endpoints[operation] = endpoint

    async def execute(self, ctx: ExecutionContext) -> RawResult:
        """Execute an HTTP request for the given context.

        The operation is resolved from ``ctx.args.get("operation")`` first,
        then falls back to ``ctx.capability_id``. All other entries of
        ``ctx.args`` are sent as query parameters for GET and DELETE, and as
        a JSON body for every other method.

        Args:
            ctx: The execution context.

        Returns:
            :class:`RawResult` containing the parsed JSON response, or
            ``None`` as data when the response body is empty.

        Raises:
            DriverError: If the endpoint is not registered, the request
                fails, the server answers with an error status, or the body
                is not valid JSON.
        """
        operation = str(ctx.args.get("operation", ctx.capability_id))
        endpoint = self._endpoints.get(operation)
        if endpoint is None:
            raise DriverError(
                f"HTTPDriver '{self._driver_id}' has no endpoint for operation='{operation}'."
            )

        method = endpoint.method.upper()
        payload: dict[str, Any] = {k: v for k, v in ctx.args.items() if k != "operation"}
        uses_query = method in _QUERY_METHODS
        headers = {**self._base_headers, **endpoint.headers}
        timeout = endpoint.timeout if endpoint.timeout is not None else self._default_timeout

        try:
            async with httpx.AsyncClient(headers=headers, timeout=timeout) as client:
                response = await client.request(
                    method,
                    endpoint.url,
                    # None (rather than {}) leaves any query string already
                    # present in the endpoint URL untouched.
                    params=(payload or None) if uses_query else None,
                    json=None if uses_query else payload,
                )
                response.raise_for_status()
                # Empty bodies (e.g. 204 No Content) are not JSON documents.
                data: Any = response.json() if response.content else None
        except httpx.HTTPStatusError as exc:
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': HTTP {exc.response.status_code} "
                f"from {endpoint.url}: {exc.response.text[:200]}"
            ) from exc
        except httpx.RequestError as exc:
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': Request to {endpoint.url} failed: {exc}"
            ) from exc
        except ValueError as exc:
            # json.JSONDecodeError is a ValueError subclass.
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': Response from {endpoint.url} "
                f"is not valid JSON: {exc}"
            ) from exc

        return RawResult(
            capability_id=ctx.capability_id,
            data=data,
            metadata={"status_code": response.status_code, "url": endpoint.url},
        )
# Forward reference keeps the alias importable without evaluating the name here.
Handler = Callable[["ExecutionContext"], Any]


class InMemoryDriver:
    """A driver that executes capabilities using registered Python callables.

    This driver is primarily intended for unit tests, demos, and
    local development where no external API is available.
    """

    def __init__(self, driver_id: str = "memory") -> None:
        self._driver_id = driver_id
        self._handlers: dict[str, Handler] = {}

    @property
    def driver_id(self) -> str:
        """Unique identifier for this driver."""
        return self._driver_id

    def register_handler(self, operation: str, handler: Handler) -> None:
        """Register a Python callable as the handler for an operation.

        Args:
            operation: The operation name (must match ``ImplementationRef.operation``).
            handler: A callable ``(ExecutionContext) -> Any`` that performs the operation.
        """
        self._handlers[operation] = handler

    async def execute(self, ctx: ExecutionContext) -> RawResult:
        """Execute a capability via its registered handler.

        The operation is looked up from ``ctx.args.get("operation")`` first,
        then falls back to ``ctx.capability_id``.

        Args:
            ctx: The execution context.

        Returns:
            :class:`RawResult` wrapping the handler's return value.

        Raises:
            DriverError: If no handler is registered or the handler raises.
        """
        operation = str(ctx.args.get("operation", ctx.capability_id))
        handler = self._handlers.get(operation)
        if handler is None:
            raise DriverError(
                f"InMemoryDriver '{self._driver_id}' has no handler for "
                f"operation='{operation}'. Register one with register_handler()."
            )
        try:
            data = handler(ctx)
        except Exception as exc:
            raise DriverError(f"Handler for operation='{operation}' raised: {exc}") from exc
        return RawResult(capability_id=ctx.capability_id, data=data)


# ── Billing dataset factory ───────────────────────────────────────────────────


def _make_billing_dataset(n: int = 200) -> list[dict[str, Any]]:
    """Generate a deterministic synthetic billing dataset.

    Uses :class:`random.Random` seeded with ``42`` so the output is always
    the same regardless of global random state.

    Args:
        n: Number of invoice records to generate; ``0`` or less yields ``[]``.

    Returns:
        A list of invoice dicts.
    """
    rng = random.Random(42)
    statuses = ["paid", "unpaid", "overdue"]
    currencies = ["USD", "EUR", "GBP"]
    first_names = ["Alice", "Bob", "Carol", "Dave", "Eve", "Frank", "Grace", "Hiro"]
    last_names = ["Smith", "Jones", "Lee", "Brown", "Taylor", "Wilson", "Davis"]

    records: list[dict[str, Any]] = []
    # NOTE: the order of rng calls below defines the dataset; do not reorder.
    for i in range(1, n + 1):
        fname = rng.choice(first_names)
        lname = rng.choice(last_names)
        name = f"{fname} {lname}"
        email = f"{fname.lower()}.{lname.lower()}{i}@example.com"
        phone = f"+1-555-{rng.randint(1000, 9999)}"
        amount = round(rng.uniform(10.0, 5000.0), 2)
        currency = rng.choice(currencies)
        status = rng.choice(statuses)
        year = rng.randint(2023, 2024)
        month = rng.randint(1, 12)
        day = rng.randint(1, 28)
        date_str = f"{year}-{month:02d}-{day:02d}"
        line_items = [
            {
                "description": f"Item {j}",
                "qty": rng.randint(1, 5),
                "unit_price": round(rng.uniform(5.0, 500.0), 2),
            }
            for j in range(1, rng.randint(1, 4) + 1)
        ]
        records.append(
            {
                "id": f"INV-{i:04d}",
                "customer_name": name,
                "email": email,
                "phone": phone,
                "amount": amount,
                "currency": currency,
                "status": status,
                "date": date_str,
                "line_items": line_items,
            }
        )
    return records


BILLING_DATASET: list[dict[str, Any]] = _make_billing_dataset()


def _copy_invoice(record: dict[str, Any]) -> dict[str, Any]:
    """Return an independent copy of *record*, including its line items."""
    clone = dict(record)
    clone["line_items"] = [dict(item) for item in record["line_items"]]
    return clone


def make_billing_driver() -> InMemoryDriver:
    """Return an :class:`InMemoryDriver` pre-loaded with billing operations.

    Operations:
    - ``list_invoices`` — returns all invoices (filtered by ``status`` if provided).
    - ``get_invoice`` — returns a single invoice by ``id``.
    - ``summarize_spend`` — returns total spend per currency/status.

    Handlers return copies of the records, so callers that mutate a result
    cannot corrupt the shared ``BILLING_DATASET``.

    Returns:
        A fully configured :class:`InMemoryDriver`.
    """
    driver = InMemoryDriver(driver_id="billing")

    def list_invoices(ctx: ExecutionContext) -> list[dict[str, Any]]:
        status_filter = ctx.args.get("status")
        return [
            _copy_invoice(record)
            for record in BILLING_DATASET
            if not status_filter or record["status"] == status_filter
        ]

    def get_invoice(ctx: ExecutionContext) -> dict[str, Any] | None:
        invoice_id = ctx.args.get("id")
        for record in BILLING_DATASET:
            if record["id"] == invoice_id:
                return _copy_invoice(record)
        return None

    def summarize_spend(ctx: ExecutionContext) -> dict[str, Any]:
        totals: dict[str, dict[str, float]] = {}
        for record in BILLING_DATASET:
            cur = record["currency"]
            sta = record["status"]
            totals.setdefault(cur, {}).setdefault(sta, 0.0)
            totals[cur][sta] = round(totals[cur][sta] + record["amount"], 2)
        return {"totals": totals, "invoice_count": len(BILLING_DATASET)}

    driver.register_handler("list_invoices", list_invoices)
    driver.register_handler("get_invoice", get_invoice)
    driver.register_handler("summarize_spend", summarize_spend)
    return driver
class SafetyClass(str, Enum):
    """Danger level of a capability's side-effects."""

    # No side-effects; safe to retry.
    READ = "READ"
    # Mutates state; requires justification and writer/admin role.
    WRITE = "WRITE"
    # Irreversible; requires admin role.
    DESTRUCTIVE = "DESTRUCTIVE"


class SensitivityTag(str, Enum):
    """Data-sensitivity label attached to a capability."""

    # No special sensitivity.
    NONE = "NONE"
    # Personally identifiable information (name, email, phone, SSN).
    PII = "PII"
    # Payment card industry data (card numbers, CVV).
    PCI = "PCI"
    # Credentials, API keys, tokens.
    SECRETS = "SECRETS"


class AgentKernelError(Exception):
    """Root of the agent-kernel exception hierarchy."""


# Token errors


class TokenExpired(AgentKernelError):
    """A token's ``expires_at`` lies in the past."""


class TokenInvalid(AgentKernelError):
    """A token's HMAC signature failed verification."""


class TokenScopeError(AgentKernelError):
    """A token was presented by the wrong principal or for the wrong capability."""


# Policy errors


class PolicyDenied(AgentKernelError):
    """The policy engine rejected a capability request."""


# Driver errors


class DriverError(AgentKernelError):
    """A driver failed to execute a capability."""


# Firewall errors


class FirewallError(AgentKernelError):
    """The context firewall could not transform a raw result."""


# Registry / lookup errors


class CapabilityNotFound(AgentKernelError):
    """The requested capability ID is not in the registry."""


# Handle errors


class HandleNotFound(AgentKernelError):
    """The requested handle ID is not in the handle store."""


class HandleExpired(AgentKernelError):
    """The handle's TTL has elapsed."""
+ """ + + max_rows: int = 50 + max_fields: int = 20 + max_chars: int = 4000 + max_depth: int = 3 diff --git a/src/agent_kernel/firewall/redaction.py b/src/agent_kernel/firewall/redaction.py new file mode 100644 index 0000000..dfc0591 --- /dev/null +++ b/src/agent_kernel/firewall/redaction.py @@ -0,0 +1,101 @@ +"""PII/PCI field redaction for the context firewall.""" + +from __future__ import annotations + +import re +from typing import Any + +# Fields that are always redacted when PII/PCI sensitivity is active +# (unless the principal has the pii_reader role). +_SENSITIVE_FIELDS: frozenset[str] = frozenset( + { + "email", + "phone", + "card_number", + "ssn", + "social_security_number", + "cvv", + "credit_card", + "password", + "secret", + } +) + +_EMAIL_RE = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+") +_PHONE_RE = re.compile(r"\+?[\d\s\-().]{7,}") +_CARD_RE = re.compile(r"\b(?:\d[ -]?){13,16}\b") +_SSN_RE = re.compile(r"\b\d{3}[- ]\d{2}[- ]\d{4}\b") + +_REDACTED = "[REDACTED]" + + +def _is_sensitive_field_name(name: str) -> bool: + return name.lower() in _SENSITIVE_FIELDS + + +def redact( + data: Any, + *, + allowed_fields: list[str] | None = None, + depth: int = 0, + max_depth: int = 3, +) -> tuple[Any, list[str]]: + """Recursively redact sensitive data from *data*. + + If *allowed_fields* is non-empty, only those fields are kept in dicts; + all others are removed. Sensitive field names are replaced with + ``[REDACTED]`` regardless. + + Args: + data: The data to redact. + allowed_fields: If non-empty, only keep these field names in dicts. + depth: Current recursion depth (used internally). + max_depth: Maximum recursion depth. + + Returns: + A tuple of ``(redacted_data, warnings)`` where *warnings* is a list of + human-readable strings describing what was redacted. 
+ """ + warnings: list[str] = [] + + if depth >= max_depth: + return data, warnings + + if isinstance(data, dict): + result: dict[str, Any] = {} + for k, v in data.items(): + if allowed_fields and k not in allowed_fields: + warnings.append(f"Field '{k}' omitted (not in allowed_fields).") + continue + if _is_sensitive_field_name(str(k)): + result[k] = _REDACTED + warnings.append(f"Field '{k}' redacted (sensitive field name).") + else: + child, child_warnings = redact( + v, allowed_fields=None, depth=depth + 1, max_depth=max_depth + ) + result[k] = child + warnings.extend(child_warnings) + return result, warnings + + if isinstance(data, list): + redacted_list = [] + for item in data: + child, child_warnings = redact( + item, allowed_fields=allowed_fields, depth=depth + 1, max_depth=max_depth + ) + redacted_list.append(child) + warnings.extend(child_warnings) + return redacted_list, warnings + + if isinstance(data, str): + original = data + data = _EMAIL_RE.sub(_REDACTED, data) + data = _PHONE_RE.sub(_REDACTED, data) + data = _CARD_RE.sub(_REDACTED, data) + data = _SSN_RE.sub(_REDACTED, data) + if data != original: + warnings.append("String value contained sensitive patterns and was redacted.") + return data, warnings + + return data, warnings diff --git a/src/agent_kernel/firewall/summarize.py b/src/agent_kernel/firewall/summarize.py new file mode 100644 index 0000000..a9f1274 --- /dev/null +++ b/src/agent_kernel/firewall/summarize.py @@ -0,0 +1,115 @@ +"""Deterministic summarization heuristics for the context firewall. + +No LLM is used — summaries are produced by structural analysis of the data. +""" + +from __future__ import annotations + +from typing import Any + + +def summarize(data: Any, *, max_facts: int = 20) -> list[str]: + """Produce a list of human-readable facts from *data*. + + Dispatches to specialised handlers based on the type of *data*: + + - **list of dicts** → count + top keys + basic stats on numeric fields. 
def summarize(data: Any, *, max_facts: int = 20) -> list[str]:
    """Produce a list of human-readable facts from *data*.

    Dispatches to specialised handlers based on the type of *data*:

    - **list whose first element is a dict** → count + top keys + basic
      stats on numeric fields + distributions of low-cardinality strings.
    - **dict** → keys + one fact per value.
    - **other list** → length + ``repr()`` of the leading items.
    - **str** → the string, truncated to 500 characters.
    - **other** → ``repr()`` truncated to 200 chars.

    Args:
        data: The data to summarise.
        max_facts: Maximum number of facts to return.

    Returns:
        An ordered list of fact strings (≤ *max_facts*).
    """
    if isinstance(data, list) and data and isinstance(data[0], dict):
        return _summarize_list_of_dicts(data, max_facts=max_facts)
    if isinstance(data, dict):
        return _summarize_dict(data, max_facts=max_facts)
    if isinstance(data, list):
        return _summarize_plain_list(data, max_facts=max_facts)
    if isinstance(data, str):
        return _summarize_string(data, max_facts=max_facts)
    return [repr(data)[:200]][:max_facts]


# ── Specialised handlers ──────────────────────────────────────────────────────


def _summarize_list_of_dicts(rows: list[dict[str, Any]], *, max_facts: int) -> list[str]:
    """Summarise tabular data: row count, keys, numeric stats, distributions.

    Rows that are not dicts are counted in the total but otherwise skipped,
    so a heterogeneous list no longer raises.
    """
    facts: list[str] = [f"Total rows: {len(rows)}"]

    records = [r for r in rows if isinstance(r, dict)]
    skipped = len(rows) - len(records)
    if skipped:
        facts.append(f"Non-dict rows skipped: {skipped}")

    # Top keys (union of keys in the first 10 rows, for performance).
    key_counts: dict[Any, int] = {}
    for row in rows[:10]:
        if not isinstance(row, dict):
            continue
        for k in row:
            key_counts[k] = key_counts.get(k, 0) + 1
    top_keys = sorted(key_counts, key=lambda k: -key_counts[k])[:10]
    facts.append(f"Top keys: {', '.join(str(k) for k in top_keys)}")

    # Numeric stats: a key qualifies only if every row that has it is numeric.
    numeric_keys = [
        k
        for k in top_keys
        if all(isinstance(r.get(k), (int, float)) for r in records if k in r)
    ]
    for k in numeric_keys[:5]:
        values = [float(r[k]) for r in records if k in r]
        if values:
            facts.append(
                f"{k}: min={min(values):.2f}, max={max(values):.2f}, "
                f"avg={sum(values) / len(values):.2f}"
            )
        if len(facts) >= max_facts:
            break

    # Status / categorical counts (string fields with few distinct values).
    for k in top_keys[:5]:
        if k in numeric_keys:
            continue
        counts: dict[str, int] = {}
        for r in records:
            if k in r and isinstance(r[k], str):
                counts[r[k]] = counts.get(r[k], 0) + 1
        if 2 <= len(counts) <= 10:
            summary = ", ".join(f"{v}={counts[v]}" for v in sorted(counts))
            facts.append(f"{k} distribution: {summary}")
        if len(facts) >= max_facts:
            break

    return facts[:max_facts]


def _summarize_dict(data: dict[str, Any], *, max_facts: int) -> list[str]:
    """Summarise a mapping: its (sorted) keys, then one fact per entry."""
    # Keys are stringified so that non-string keys cannot break the join.
    key_names = sorted(str(k) for k in data)
    facts: list[str] = [f"Keys: {', '.join(key_names[:20])}"]
    for k, v in list(data.items())[: max_facts - 1]:
        if isinstance(v, (int, float)):
            facts.append(f"{k}: {v}")
        elif isinstance(v, str):
            facts.append(f"{k}: {v[:80]}")
        elif isinstance(v, list):
            facts.append(f"{k}: list of {len(v)} items")
        elif isinstance(v, dict):
            inner = ", ".join(str(key) for key in list(v.keys())[:5])
            facts.append(f"{k}: dict with keys [{inner}]")
        else:
            facts.append(f"{k}: {repr(v)[:80]}")
        if len(facts) >= max_facts:
            break
    return facts[:max_facts]


def _summarize_plain_list(data: list[Any], *, max_facts: int) -> list[str]:
    """Summarise a list: its length, then ``repr()`` of the leading items."""
    facts = [f"List of {len(data)} items"]
    for item in data[: max_facts - 1]:
        facts.append(repr(item)[:100])
    return facts[:max_facts]


def _summarize_string(data: str, *, max_facts: int) -> list[str]:
    """Summarise a string by truncating it to 500 characters."""
    truncated = data[:500]
    if len(data) > 500:
        truncated += f"… ({len(data)} chars total)"
    return [truncated][:max_facts]
class Firewall:
    """Transforms :class:`RawResult` objects into LLM-safe :class:`Frame` objects.

    The firewall enforces:
    - Row, field, character, and depth budgets.
    - Redaction of the driver output. Redaction always runs; it is restricted
      to ``allowed_fields`` when that constraint is present.
    - Four response modes: ``summary``, ``table``, ``handle_only``, ``raw``.
    """

    def __init__(self, budgets: _Budgets | FirewallBudgets | None = None) -> None:
        """Create a firewall.

        Args:
            budgets: Budget limits to enforce; defaults to ``Budgets()``.
        """
        self._budgets: _Budgets | FirewallBudgets = _Budgets() if budgets is None else budgets

    def transform(
        self,
        raw: RawResult,
        *,
        action_id: str,
        principal_id: str,
        principal_roles: list[str],
        response_mode: ResponseMode,
        constraints: dict[str, Any] | None = None,
        handle: Handle | None = None,
    ) -> Frame:
        """Transform a raw result into a Frame.

        Args:
            raw: The driver output to transform.
            action_id: The audit action ID.
            principal_id: Principal making the request.
            principal_roles: Principal's roles (used for the ``raw`` mode gate).
            response_mode: How to present the data. ``raw`` is honoured only
                for principals with the ``admin`` role; everyone else gets a
                summary plus a warning.
            constraints: Active execution constraints (may include ``max_rows``,
                ``allowed_fields``).
            handle: Pre-created handle for the full dataset.

        Returns:
            A bounded :class:`Frame`. In ``raw`` mode ``raw_data`` holds the
            JSON-normalised data when it fits in ``max_chars``; otherwise it
            holds the truncated JSON *text* and a warning is added.

        Raises:
            TypeError, ValueError: Propagated from ``json.dumps`` in ``raw``
                mode (e.g. unsupported dict keys, circular references), or
                from ``int()`` when ``max_rows`` is not integer-like.
        """
        constraints = constraints or {}
        # A negative cap would make ``data[:max_rows]`` mean "all but the last
        # n rows", so clamp at zero.
        max_rows = max(0, int(constraints.get("max_rows", self._budgets.max_rows)))
        # ``or []`` also covers an explicit ``allowed_fields=None``.
        allowed_fields: list[str] = list(constraints.get("allowed_fields") or [])

        provenance = Provenance(
            capability_id=raw.capability_id,
            principal_id=principal_id,
            invoked_at=datetime.datetime.now(tz=datetime.timezone.utc),
            action_id=action_id,
        )
        warnings: list[str] = []

        # ── Redaction ──────────────────────────────────────────────────────────
        # Redaction always runs; without allowed_fields it still catches inline PII.
        if allowed_fields:
            data, redact_warnings = redact(
                raw.data,
                allowed_fields=allowed_fields,
                max_depth=self._budgets.max_depth,
            )
        else:
            data, redact_warnings = redact(raw.data, max_depth=self._budgets.max_depth)
        warnings.extend(redact_warnings)

        # Fields shared by every Frame built below.  ``warnings`` is passed by
        # reference, so later appends are still included.
        common: dict[str, Any] = {
            "action_id": action_id,
            "capability_id": raw.capability_id,
            "handle": handle,
            "warnings": warnings,
            "provenance": provenance,
        }

        # ── Raw mode (admin only) ──────────────────────────────────────────────
        if response_mode == "raw":
            if "admin" in principal_roles:
                max_chars = self._budgets.max_chars
                serialized = json.dumps(data, default=str)
                if len(serialized) <= max_chars:
                    # Round-trip normalises the payload to plain JSON types.
                    raw_payload: Any = json.loads(serialized)
                else:
                    # Truncated JSON text is not valid JSON, so it must not be
                    # parsed; return the bounded prefix as a string instead.
                    raw_payload = _truncate_str(serialized, max_chars)
                    warnings.append(
                        f"raw output truncated to {max_chars} of {len(serialized)} "
                        "characters; returned as a JSON text prefix."
                    )
                return Frame(response_mode="raw", raw_data=raw_payload, **common)
            warnings.append("raw mode requires admin role; falling back to summary.")
            response_mode = "summary"

        # ── Handle only ───────────────────────────────────────────────────────
        if response_mode == "handle_only":
            return Frame(response_mode="handle_only", **common)

        # ── Table mode ────────────────────────────────────────────────────────
        if response_mode == "table":
            return Frame(
                response_mode="table",
                table_preview=self._make_table(data, max_rows=max_rows),
                **common,
            )

        # ── Summary mode (default) ────────────────────────────────────────────
        # Enforce the character budget across all facts.
        facts = _cap_facts(summarize(data, max_facts=20), self._budgets.max_chars)
        return Frame(response_mode="summary", facts=facts, **common)

    # ── Helpers ───────────────────────────────────────────────────────────────

    def _make_table(self, data: Any, *, max_rows: int) -> list[dict[str, Any]]:
        """Convert *data* to a list of dicts, capped at *max_rows* rows.

        Lists are used row by row, a dict becomes a single row, and anything
        else becomes one ``{"value": data}`` row. Dict rows keep only their
        first ``max_fields`` items; non-dict rows are wrapped as
        ``{"value": row}``.
        """
        if isinstance(data, list):
            rows: list[Any] = data
        elif isinstance(data, dict):
            rows = [data]
        else:
            rows = [{"value": data}]

        max_fields = self._budgets.max_fields
        return [
            dict(list(row.items())[:max_fields]) if isinstance(row, dict) else {"value": row}
            for row in rows[:max_rows]
        ]
def store(
    self,
    capability_id: str,
    data: Any,
    *,
    ttl_seconds: int | None = None,
) -> Handle:
    """Keep *data* in the store and hand back a :class:`Handle` for it.

    Args:
        capability_id: The capability that produced *data*.
        data: The full dataset to keep.
        ttl_seconds: Lifetime in seconds; ``None`` means the store default.

    Returns:
        The newly created :class:`Handle`. Its ``total_rows`` is the list
        length for list data and ``1`` for anything else.
    """
    lifetime = self._default_ttl if ttl_seconds is None else ttl_seconds
    created = datetime.datetime.now(tz=datetime.timezone.utc)

    if isinstance(data, list):
        row_count = len(data)
    else:
        row_count = 1

    new_id = str(uuid.uuid4())
    new_handle = Handle(
        handle_id=new_id,
        capability_id=capability_id,
        created_at=created,
        expires_at=created + datetime.timedelta(seconds=lifetime),
        total_rows=row_count,
    )
    # Data and metadata live in parallel dicts keyed by the handle ID.
    self._data[new_handle.handle_id] = data
    self._meta[new_handle.handle_id] = new_handle
    return new_handle
def expand(
    self,
    handle: Handle,
    *,
    query: dict[str, Any],
    action_id: str = "",
    response_mode: ResponseMode = "table",
) -> Frame:
    """Expand a handle with optional filtering, pagination, and field selection.

    Steps are applied in that order. Supported query keys:
    - ``filter`` (dict[str, Any]): Basic equality filter (all conditions
      AND-ed); non-dict rows never match.
    - ``offset`` (int): Skip this many rows (negative values are treated as 0).
    - ``limit`` (int): Return at most this many rows (negative values are
      treated as 0; missing or ``None`` means no limit).
    - ``fields`` (list[str]): Only include these fields of dict rows.

    Rows are returned as stored; this method applies no redaction of its own.

    Args:
        handle: The handle to expand.
        query: Query parameters controlling the expansion.
        action_id: Audit action ID to embed in the returned Frame.
        response_mode: Response mode for the returned Frame.

    Returns:
        A :class:`Frame` whose ``table_preview`` holds the selected rows. Dict
        rows are kept as dicts, other rows are wrapped as ``{"value": row}``,
        and an empty selection yields an empty list.

    Raises:
        HandleNotFound: If the handle ID is unknown.
        HandleExpired: If the handle's TTL has elapsed.
    """
    data = self.get(handle.handle_id)
    rows: list[Any] = data if isinstance(data, list) else [data]

    # ── Filtering ──────────────────────────────────────────────────────────
    filter_spec = query.get("filter", {})
    if filter_spec and isinstance(filter_spec, dict):
        rows = [
            r
            for r in rows
            if isinstance(r, dict) and all(r.get(k) == v for k, v in filter_spec.items())
        ]

    # ── Pagination ────────────────────────────────────────────────────────
    # Clamp at zero: negative slice bounds would count from the end of the list.
    offset = max(0, int(query.get("offset") or 0))
    requested_limit = query.get("limit")
    limit = len(rows) if requested_limit is None else max(0, int(requested_limit))
    rows = rows[offset : offset + limit]

    # ── Field selection ───────────────────────────────────────────────────
    fields: list[str] = list(query.get("fields") or [])
    if fields:
        rows = [
            {k: v for k, v in r.items() if k in fields} if isinstance(r, dict) else r
            for r in rows
        ]

    # Normalise row by row.  This avoids indexing ``rows[0]`` on an empty
    # selection and keeps non-dict rows out of the table.
    table_preview: list[dict[str, Any]] = [
        r if isinstance(r, dict) else {"value": r} for r in rows
    ]

    return Frame(
        action_id=action_id,
        capability_id=handle.capability_id,
        response_mode=response_mode,
        table_preview=table_preview,
        handle=handle,
        provenance=Provenance(
            capability_id=handle.capability_id,
            principal_id="",
            invoked_at=datetime.datetime.now(tz=datetime.timezone.utc),
            action_id=action_id,
        ),
    )
class Kernel:
    """The central orchestrator for capability-based AI agent security.

    The Kernel wires together the registry, policy engine, token provider,
    router, firewall, handle store, and trace store into a single coherent
    interface.

    Example::

        registry = CapabilityRegistry()
        registry.register(Capability(...))
        kernel = Kernel(registry)
        kernel.register_driver(driver)

        requests = kernel.request_capabilities("list invoices")
        token = kernel.get_token(requests[0], principal, justification="...")
        frame = await kernel.invoke(
            token, principal=principal, args={"operation": "list_invoices"}
        )
    """

    def __init__(
        self,
        registry: CapabilityRegistry,
        policy: PolicyEngine | None = None,
        token_provider: TokenProvider | None = None,
        router: Router | None = None,
        firewall: Firewall | None = None,
        handle_store: HandleStore | None = None,
        trace_store: TraceStore | None = None,
    ) -> None:
        """Create a kernel; every collaborator left as ``None`` gets a default.

        Args:
            registry: Capability registry used for lookup and search.
            policy: Policy engine (default: :class:`DefaultPolicyEngine`).
            token_provider: Token issuer/verifier (default: :class:`HMACTokenProvider`).
            router: Capability-to-driver router (default: :class:`StaticRouter`).
            firewall: Context firewall (default: :class:`Firewall`).
            handle_store: Store for full results (default: :class:`HandleStore`).
            trace_store: Audit trace store (default: :class:`TraceStore`).
        """
        # ``is None`` rather than ``x or Default()``: a caller-supplied
        # collaborator that happens to be falsy (for example an empty store
        # that defines ``__len__``) must not be silently replaced.
        self._registry = registry
        self._policy: PolicyEngine = DefaultPolicyEngine() if policy is None else policy
        self._token_provider: TokenProvider = (
            HMACTokenProvider() if token_provider is None else token_provider
        )
        self._router: Router = StaticRouter() if router is None else router
        self._firewall = Firewall() if firewall is None else firewall
        self._handle_store = HandleStore() if handle_store is None else handle_store
        self._trace_store = TraceStore() if trace_store is None else trace_store
        self._drivers: dict[str, Driver] = {}

    # ── Driver registration ────────────────────────────────────────────────────

    def register_driver(self, driver: Driver) -> None:
        """Register a driver with the kernel.

        A driver registered under an existing ``driver_id`` replaces the
        previous one.

        Args:
            driver: Any object implementing the :class:`~agent_kernel.drivers.base.Driver` protocol.
        """
        self._drivers[driver.driver_id] = driver

    # ── Public API ─────────────────────────────────────────────────────────────

    def request_capabilities(
        self,
        goal: str,
        *,
        context_tags: dict[str, str] | None = None,
    ) -> list[CapabilityRequest]:
        """Discover capabilities that match a natural-language goal.

        Args:
            goal: Free-text description of the agent's intent.
            context_tags: Optional metadata to narrow the search (currently unused).

        Returns:
            The registry's search results for *goal*.
        """
        return self._registry.search(goal)

    def _authorize(
        self,
        request: CapabilityRequest,
        principal: Principal,
        *,
        justification: str,
    ) -> tuple[Any, CapabilityToken]:
        """Evaluate the policy and issue a token; shared by the two public entry points.

        Returns:
            ``(decision, token)``.

        Raises:
            PolicyDenied: If the engine raises it, or returns a decision with
                ``allowed=False`` (engines may report denial either way).
        """
        # Local import: PolicyDenied is not among this module's top-level imports.
        from .errors import PolicyDenied

        capability = self._registry.get(request.capability_id)
        decision = self._policy.evaluate(
            request, capability, principal, justification=justification
        )
        if not decision.allowed:
            raise PolicyDenied(
                decision.reason
                or f"Policy denied capability '{capability.capability_id}' "
                f"for principal '{principal.principal_id}'."
            )
        token = self._token_provider.issue(
            capability.capability_id,
            principal.principal_id,
            constraints=decision.constraints,
            audit_id=str(uuid.uuid4()),
        )
        return decision, token

    def grant_capability(
        self,
        request: CapabilityRequest,
        principal: Principal,
        *,
        justification: str,
    ) -> CapabilityGrant:
        """Evaluate the policy and, if approved, issue a signed token.

        The returned grant carries only the token's ID; use :meth:`get_token`
        when the token itself is needed for :meth:`invoke`.

        Args:
            request: The capability request to evaluate.
            principal: The principal requesting access.
            justification: Free-text justification for the request.

        Returns:
            A :class:`CapabilityGrant` recording the decision and the token ID.

        Raises:
            PolicyDenied: If the policy engine rejects the request.
            CapabilityNotFound: If the requested capability is not registered.
        """
        decision, token = self._authorize(request, principal, justification=justification)
        return CapabilityGrant(
            request=request,
            principal=principal,
            decision=decision,
            token_id=token.token_id,
        )

    def get_token(
        self,
        request: CapabilityRequest,
        principal: Principal,
        *,
        justification: str,
    ) -> CapabilityToken:
        """Like :meth:`grant_capability` but returns the token directly.

        This is a convenience method for use with :meth:`invoke`.

        Args:
            request: The capability request.
            principal: The requesting principal.
            justification: Free-text justification.

        Returns:
            A signed :class:`CapabilityToken`.

        Raises:
            PolicyDenied: If the policy engine rejects the request.
            CapabilityNotFound: If the capability is not registered.
        """
        _, token = self._authorize(request, principal, justification=justification)
        return token

    async def invoke(
        self,
        token: CapabilityToken,
        *,
        principal: Principal,
        args: dict[str, Any],
        response_mode: ResponseMode = "summary",
    ) -> Frame:
        """Execute a capability using a signed token and return a Frame.

        Drivers from the route plan are tried in order. Unregistered driver IDs
        are skipped, a :class:`DriverError` moves on to the next driver, and
        the first success wins. A trace is recorded on success and when every
        driver fails.

        Args:
            token: A signed :class:`CapabilityToken` authorising the invocation.
            principal: The principal invoking the capability (must match token).
            args: Arguments passed to the driver.
            response_mode: How to present the result (``summary``, ``table``,
                ``handle_only``, or ``raw``).

        Returns:
            The :class:`Frame` produced by the firewall.

        Raises:
            CapabilityNotFound: If the capability is not registered.
            DriverError: If no driver produced a result.
            Exception: Whatever the token provider's ``verify`` raises for an
                invalid token (presumably TokenExpired / TokenInvalid /
                TokenScopeError, which are defined outside this module).
        """
        # ── Verify token ──────────────────────────────────────────────────────
        self._token_provider.verify(
            token,
            expected_principal_id=principal.principal_id,
            expected_capability_id=token.capability_id,
        )

        action_id = str(uuid.uuid4())
        self._registry.get(token.capability_id)  # validate capability exists
        plan: RoutePlan = self._router.route(token.capability_id)

        # ── Execute with fallback ─────────────────────────────────────────────
        raw_result = None
        used_driver_id = ""
        last_error: Exception | None = None

        for driver_id in plan.driver_ids:
            driver = self._drivers.get(driver_id)
            if driver is None:
                continue
            # A fresh context per attempt, so drivers never share one.
            ctx = ExecutionContext(
                capability_id=token.capability_id,
                principal_id=principal.principal_id,
                args=args,
                constraints=token.constraints,
                action_id=action_id,
            )
            try:
                raw_result = await driver.execute(ctx)
            except DriverError as exc:
                last_error = exc
                continue
            used_driver_id = driver_id
            break

        if raw_result is None:
            err_msg = str(last_error) if last_error else "No drivers available."
            self._trace_store.record(
                ActionTrace(
                    action_id=action_id,
                    capability_id=token.capability_id,
                    principal_id=principal.principal_id,
                    token_id=token.token_id,
                    invoked_at=datetime.datetime.now(tz=datetime.timezone.utc),
                    args=args,
                    response_mode=response_mode,
                    driver_id="",
                    error=err_msg,
                )
            )
            raise DriverError(
                f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}"
            )

        # ── Store handle ──────────────────────────────────────────────────────
        # Only a raw response that will actually be honoured (admin principal)
        # goes without a handle.  A non-admin "raw" request is downgraded to a
        # summary by the firewall and still needs a handle for later expansion.
        handle: Handle | None = None
        if response_mode != "raw" or "admin" not in principal.roles:
            handle = self._handle_store.store(
                capability_id=token.capability_id,
                data=raw_result.data,
            )

        # ── Firewall transform ────────────────────────────────────────────────
        frame = self._firewall.transform(
            raw_result,
            action_id=action_id,
            principal_id=principal.principal_id,
            principal_roles=list(principal.roles),
            response_mode=response_mode,
            constraints=token.constraints,
            handle=handle,
        )

        # ── Record trace ──────────────────────────────────────────────────────
        self._trace_store.record(
            ActionTrace(
                action_id=action_id,
                capability_id=token.capability_id,
                principal_id=principal.principal_id,
                token_id=token.token_id,
                invoked_at=datetime.datetime.now(tz=datetime.timezone.utc),
                args=args,
                response_mode=response_mode,
                driver_id=used_driver_id,
                handle_id=handle.handle_id if handle else None,
            )
        )

        return frame

    def expand(self, handle: Handle, *, query: dict[str, Any]) -> Frame:
        """Expand a handle with pagination, field selection, or filtering.

        NOTE(review): the handle store holds the driver's raw data (see
        :meth:`invoke`) and this path does not go through the firewall, so no
        token check or redaction happens here. Confirm that is intended.

        Args:
            handle: The :class:`Handle` to expand.
            query: Query parameters (``offset``, ``limit``, ``fields``, ``filter``).

        Returns:
            A :class:`Frame` with the requested slice of data.

        Raises:
            HandleNotFound: If the handle is unknown.
            HandleExpired: If the handle has expired.
        """
        return self._handle_store.expand(handle, query=query)

    def explain(self, action_id: str) -> ActionTrace:
        """Retrieve the audit trace for a past invocation.

        Args:
            action_id: The unique action identifier returned in a :class:`Frame`.

        Returns:
            The :class:`ActionTrace` the trace store holds for that action.
        """
        return self._trace_store.get(action_id)
``"billing.list_invoices"``).""" + + name: str + """Short human-readable name.""" + + description: str + """What the capability does.""" + + safety_class: SafetyClass + """READ / WRITE / DESTRUCTIVE.""" + + sensitivity: SensitivityTag = SensitivityTag.NONE + """Optional sensitivity tag.""" + + allowed_fields: list[str] = field(default_factory=list) + """If non-empty, only these fields are returned unless the caller has ``pii_reader``.""" + + tags: list[str] = field(default_factory=list) + """Arbitrary keyword tags used for capability matching.""" + + impl: ImplementationRef | None = None + """Optional pointer to the implementation.""" + + +# ── Request / Grant ─────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class CapabilityRequest: + """A request for authorization to use a capability.""" + + capability_id: str + """The capability being requested.""" + + goal: str + """Free-text description of why this capability is needed.""" + + constraints: dict[str, Any] = field(default_factory=dict) + """Optional execution constraints (e.g. ``{"max_rows": 10}``).""" + + +@dataclass(slots=True) +class Principal: + """Represents the entity (agent, user, service) making a request.""" + + principal_id: str + """Unique identifier (UUID or slug).""" + + roles: list[str] = field(default_factory=list) + """Role strings, e.g. ``["reader", "admin"]``.""" + + attributes: dict[str, str] = field(default_factory=dict) + """Arbitrary attributes, e.g. ``{"tenant": "acme"}``.""" + + +@dataclass(slots=True) +class PolicyDecision: + """Result of a policy engine evaluation.""" + + allowed: bool + """``True`` if the request is permitted.""" + + reason: str + """Human-readable explanation.""" + + constraints: dict[str, Any] = field(default_factory=dict) + """Any additional constraints imposed by the policy (e.g. 
``max_rows``).""" + + +@dataclass(slots=True) +class CapabilityGrant: + """A signed authorization binding a principal to a capability.""" + + request: CapabilityRequest + """The original request.""" + + principal: Principal + """The principal this grant is issued to.""" + + decision: PolicyDecision + """The policy decision that led to this grant.""" + + token_id: str + """The token's unique identifier.""" + + +# ── Routing ─────────────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class RoutePlan: + """Maps a capability to an ordered list of driver IDs to try.""" + + capability_id: str + driver_ids: list[str] + """Ordered list; first that succeeds wins.""" + + +# ── Raw results & Frames ────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class RawResult: + """Unfiltered output from a driver execution.""" + + capability_id: str + data: Any + """Arbitrary data returned by the driver.""" + + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(slots=True) +class Handle: + """An opaque reference to a full dataset stored in the HandleStore.""" + + handle_id: str + capability_id: str + created_at: datetime.datetime + expires_at: datetime.datetime + total_rows: int = 0 + + +@dataclass(slots=True) +class Provenance: + """Tracks the origin of information in a Frame.""" + + capability_id: str + principal_id: str + invoked_at: datetime.datetime + action_id: str + + +@dataclass(slots=True) +class Budgets: + """Budget constraints for the context firewall.""" + + max_rows: int = 50 + max_fields: int = 20 + max_chars: int = 4000 + max_depth: int = 3 + + +@dataclass(slots=True) +class FieldSpec: + """Describes a single field in a structured result.""" + + name: str + value_type: str + + +@dataclass(slots=True) +class Frame: + """Bounded, LLM-safe representation of a capability result. + + The firewall always returns a Frame; raw data is never passed to the LLM. 
+ """ + + action_id: str + capability_id: str + response_mode: ResponseMode + + facts: list[str] = field(default_factory=list) + """Key facts extracted from the result (≤ 20 items).""" + + table_preview: list[dict[str, Any]] = field(default_factory=list) + """Tabular preview (≤ max_rows rows).""" + + handle: Handle | None = None + """Opaque reference to the full dataset for later expansion.""" + + warnings: list[str] = field(default_factory=list) + """Non-fatal warnings (e.g. redacted fields).""" + + provenance: Provenance | None = None + """Audit provenance of this frame.""" + + raw_data: Any = None + """Only populated in ``raw`` response mode for admin principals.""" + + +# ── Audit trace ─────────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class ActionTrace: + """Complete audit record for a single kernel invocation.""" + + action_id: str + capability_id: str + principal_id: str + token_id: str + invoked_at: datetime.datetime + args: dict[str, Any] + response_mode: ResponseMode + driver_id: str + handle_id: str | None = None + error: str | None = None diff --git a/src/agent_kernel/policy.py b/src/agent_kernel/policy.py new file mode 100644 index 0000000..7aecf41 --- /dev/null +++ b/src/agent_kernel/policy.py @@ -0,0 +1,135 @@ +"""Policy engine: role-based access control with confused-deputy prevention.""" + +from __future__ import annotations + +from typing import Any, Protocol + +from .enums import SafetyClass, SensitivityTag +from .errors import PolicyDenied +from .models import Capability, CapabilityRequest, PolicyDecision, Principal + +# Minimum justification length for WRITE operations. +_MIN_JUSTIFICATION = 15 + +# Default max_rows caps. 
class DefaultPolicyEngine:
    """Rule-based policy engine implementing the default access control policy.

    Rules (evaluated in order):

    1. **READ** — always allowed.
    2. **WRITE** — requires:
       - ``justification`` of at least 15 non-blank characters.
       - Principal role ``"writer"`` **or** ``"admin"``.
    3. **DESTRUCTIVE** — requires principal role ``"admin"``.
    4. **PII / PCI sensitivity** — requires the ``tenant`` attribute on the
       principal. Enforces ``allowed_fields`` unless the principal has the
       ``pii_reader`` role.
    5. **max_rows** — 50 for regular users; 500 for principals with the
       ``"service"`` role. A request may ask for fewer rows, never more.
    """

    def evaluate(
        self,
        request: CapabilityRequest,
        capability: Capability,
        principal: Principal,
        *,
        justification: str,
    ) -> PolicyDecision:
        """Evaluate the request against the default policy rules.

        Args:
            request: The capability request being evaluated.
            capability: The target capability.
            principal: The requesting principal.
            justification: Free-text justification from the caller.

        Returns:
            :class:`PolicyDecision` with ``allowed=True`` and the constraints
            to enforce: the request's own constraints, plus ``allowed_fields``
            where rule 4 applies, plus a ``max_rows`` between 0 and the role's
            cap.

        Raises:
            PolicyDenied: When the request violates a policy rule, or when its
                ``max_rows`` constraint is not an integer.
        """
        roles = set(principal.roles)
        # Copy so the caller's request is never mutated.
        constraints: dict[str, Any] = dict(request.constraints)

        # ── Safety class checks ───────────────────────────────────────────────

        if capability.safety_class == SafetyClass.WRITE:
            # Surrounding whitespace does not count towards the minimum length.
            justification_length = len(justification.strip())
            if justification_length < _MIN_JUSTIFICATION:
                raise PolicyDenied(
                    f"WRITE capabilities require a justification of at least "
                    f"{_MIN_JUSTIFICATION} characters. "
                    f"Got {justification_length} characters."
                )
            if not (roles & {"writer", "admin"}):
                raise PolicyDenied(
                    f"WRITE capabilities require the 'writer' or 'admin' role. "
                    f"Principal '{principal.principal_id}' has roles: {sorted(roles)}."
                )

        elif capability.safety_class == SafetyClass.DESTRUCTIVE:
            if "admin" not in roles:
                raise PolicyDenied(
                    f"DESTRUCTIVE capabilities require the 'admin' role. "
                    f"Principal '{principal.principal_id}' has roles: {sorted(roles)}."
                )

        # ── Sensitivity checks ────────────────────────────────────────────────

        if capability.sensitivity in (SensitivityTag.PII, SensitivityTag.PCI):
            if "tenant" not in principal.attributes:
                raise PolicyDenied(
                    f"Capability '{capability.capability_id}' has "
                    f"{capability.sensitivity.value} sensitivity and requires "
                    "the principal to have a 'tenant' attribute."
                )
            # Enforce allowed_fields unless the principal is a pii_reader.
            # Copy the list so later edits to the constraints cannot alter the
            # registered capability.
            if capability.allowed_fields and "pii_reader" not in roles:
                constraints["allowed_fields"] = list(capability.allowed_fields)

        # ── Row cap ───────────────────────────────────────────────────────────

        row_cap = _MAX_ROWS_SERVICE if "service" in roles else _MAX_ROWS_USER
        if "max_rows" in constraints:
            requested = constraints["max_rows"]
            try:
                requested_rows = int(requested)
            except (TypeError, ValueError):
                raise PolicyDenied(
                    f"Constraint 'max_rows' must be an integer; got {requested!r}."
                ) from None
            # Respect a tighter request, but never go below zero: a negative
            # value would pass through ``min()`` and, used as a slice bound
            # downstream, return almost every row.
            constraints["max_rows"] = max(0, min(requested_rows, row_cap))
        else:
            constraints["max_rows"] = row_cap

        return PolicyDecision(
            allowed=True,
            reason="Request approved by DefaultPolicyEngine.",
            constraints=constraints,
        )
+ """ + for cap in capabilities: + self.register(cap) + + # ── Lookup ──────────────────────────────────────────────────────────────── + + def get(self, capability_id: str) -> Capability: + """Retrieve a capability by its ID. + + Args: + capability_id: The capability's stable identifier. + + Returns: + The matching :class:`Capability`. + + Raises: + CapabilityNotFound: If no capability with that ID exists. + """ + try: + return self._store[capability_id] + except KeyError: + raise CapabilityNotFound( + f"No capability registered with id='{capability_id}'. " + "Check the capability_id or register it first." + ) from None + + def list_all(self) -> list[Capability]: + """Return all registered capabilities in registration order.""" + return list(self._store.values()) + + # ── Keyword matching ────────────────────────────────────────────────────── + + def search(self, goal: str, *, max_results: int = 10) -> list[CapabilityRequest]: + """Search for capabilities matching a goal string. + + Splits *goal* into tokens and scores capabilities by how many tokens + appear in their ``capability_id``, ``name``, ``description``, or + ``tags``. Returns the top results as :class:`CapabilityRequest` objects. + + Args: + goal: Free-text description of the user's intent. + max_results: Maximum number of results to return. + + Returns: + Ordered list (highest score first) of :class:`CapabilityRequest`. 
+ """ + tokens = self._tokenize(goal) + if not tokens: + return [] + + scored: list[tuple[int, Capability]] = [] + for cap in self._store.values(): + score = self._score(cap, tokens) + if score > 0: + scored.append((score, cap)) + + scored.sort(key=lambda x: (-x[0], x[1].capability_id)) + return [ + CapabilityRequest(capability_id=cap.capability_id, goal=goal) + for _, cap in scored[:max_results] + ] + + # ── Helpers ─────────────────────────────────────────────────────────────── + + @staticmethod + def _tokenize(text: str) -> list[str]: + """Split text into lower-case word tokens.""" + return re.findall(r"[a-z0-9]+", text.lower()) + + @staticmethod + def _score(cap: Capability, tokens: list[str]) -> int: + """Return a match score for a capability against query tokens.""" + corpus = " ".join( + [ + cap.capability_id, + cap.name, + cap.description, + ] + + cap.tags + ).lower() + return sum(1 for t in tokens if t in corpus) diff --git a/src/agent_kernel/router.py b/src/agent_kernel/router.py new file mode 100644 index 0000000..3da2e75 --- /dev/null +++ b/src/agent_kernel/router.py @@ -0,0 +1,65 @@ +"""Router: maps a capability to an ordered list of drivers to try.""" + +from __future__ import annotations + +from typing import Protocol + +from .models import RoutePlan + + +class Router(Protocol): + """Interface for routing a capability invocation to drivers.""" + + def route(self, capability_id: str) -> RoutePlan: + """Return an ordered list of driver IDs to try for *capability_id*. + + Args: + capability_id: The capability being invoked. + + Returns: + A :class:`RoutePlan` with an ordered ``driver_ids`` list. + """ + ... + + +class StaticRouter: + """A router backed by a static mapping of capability → driver IDs. + + Capabilities not in the explicit map fall back to a configurable default + driver list (e.g. ``["memory"]``). 
+ """ + + def __init__( + self, + routes: dict[str, list[str]] | None = None, + fallback: list[str] | None = None, + ) -> None: + """Initialise the router. + + Args: + routes: Explicit ``{capability_id: [driver_id, ...]}`` mapping. + fallback: Driver IDs to use when no explicit route is found. + """ + self._routes: dict[str, list[str]] = routes or {} + self._fallback: list[str] = fallback or ["memory"] + + def add_route(self, capability_id: str, driver_ids: list[str]) -> None: + """Add or replace a route. + + Args: + capability_id: The capability to route. + driver_ids: Ordered list of driver IDs. + """ + self._routes[capability_id] = driver_ids + + def route(self, capability_id: str) -> RoutePlan: + """Return a :class:`RoutePlan` for *capability_id*. + + Args: + capability_id: The capability being invoked. + + Returns: + The explicit route if defined, otherwise the fallback route. + """ + driver_ids = self._routes.get(capability_id, self._fallback) + return RoutePlan(capability_id=capability_id, driver_ids=list(driver_ids)) diff --git a/src/agent_kernel/tokens.py b/src/agent_kernel/tokens.py new file mode 100644 index 0000000..5520ebb --- /dev/null +++ b/src/agent_kernel/tokens.py @@ -0,0 +1,254 @@ +"""HMAC-SHA256 token provider for capability authorization.""" + +from __future__ import annotations + +import datetime +import hashlib +import hmac +import json +import logging +import os +import secrets +import uuid +from dataclasses import dataclass, field +from typing import Any, Protocol + +from .errors import TokenExpired, TokenInvalid, TokenScopeError + +logger = logging.getLogger(__name__) + +_DEV_SECRET: str | None = None + + +def _get_secret() -> str: + """Return the HMAC secret from the environment or generate a dev fallback.""" + global _DEV_SECRET + secret = os.environ.get("AGENT_KERNEL_SECRET") + if secret: + return secret + if _DEV_SECRET is None: + _DEV_SECRET = secrets.token_hex(32) + logger.warning( + "AGENT_KERNEL_SECRET is not set. 
" + "Using a random development secret — tokens will not survive restarts. " + "Set AGENT_KERNEL_SECRET in production." + ) + return _DEV_SECRET + + +# ── Token dataclass ─────────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class CapabilityToken: + """A signed, time-bounded, principal-scoped authorization token. + + Warning: + Tokens are tamper-evident (HMAC-SHA256) but **not encrypted**. + Do not put sensitive data in token fields. + """ + + token_id: str + capability_id: str + principal_id: str + issued_at: datetime.datetime + expires_at: datetime.datetime + constraints: dict[str, Any] = field(default_factory=dict) + audit_id: str = "" + signature: str = "" + + # ── Serialization ───────────────────────────────────────────────────────── + + def _signable_payload(self) -> str: + """Return the canonical JSON string used as the HMAC message.""" + payload = { + "token_id": self.token_id, + "capability_id": self.capability_id, + "principal_id": self.principal_id, + "issued_at": self.issued_at.isoformat(), + "expires_at": self.expires_at.isoformat(), + "constraints": self.constraints, + "audit_id": self.audit_id, + } + return json.dumps(payload, sort_keys=True, separators=(",", ":")) + + def to_dict(self) -> dict[str, Any]: + """Serialise the token to a plain dict (suitable for JSON transport).""" + return { + "token_id": self.token_id, + "capability_id": self.capability_id, + "principal_id": self.principal_id, + "issued_at": self.issued_at.isoformat(), + "expires_at": self.expires_at.isoformat(), + "constraints": self.constraints, + "audit_id": self.audit_id, + "signature": self.signature, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CapabilityToken: + """Reconstruct a token from a plain dict.""" + return cls( + token_id=data["token_id"], + capability_id=data["capability_id"], + principal_id=data["principal_id"], + issued_at=datetime.datetime.fromisoformat(data["issued_at"]), + 
expires_at=datetime.datetime.fromisoformat(data["expires_at"]), + constraints=data.get("constraints", {}), + audit_id=data.get("audit_id", ""), + signature=data.get("signature", ""), + ) + + +# ── Protocol ────────────────────────────────────────────────────────────────── + + +class TokenProvider(Protocol): + """Interface for token issuance and verification.""" + + def issue( + self, + capability_id: str, + principal_id: str, + *, + constraints: dict[str, Any] | None = None, + ttl_seconds: int = 3600, + audit_id: str = "", + ) -> CapabilityToken: + """Issue a new token. + + Args: + capability_id: The capability this token authorises. + principal_id: The principal this token is issued to. + constraints: Optional execution constraints. + ttl_seconds: How long the token is valid (default 1 hour). + audit_id: Audit trail ID to embed in the token. + + Returns: + A freshly signed :class:`CapabilityToken`. + """ + ... + + def verify( + self, + token: CapabilityToken, + *, + expected_principal_id: str, + expected_capability_id: str, + ) -> None: + """Verify a token. + + Args: + token: The token to verify. + expected_principal_id: The principal that should own this token. + expected_capability_id: The capability this token should authorize. + + Raises: + TokenExpired: If the token has expired. + TokenInvalid: If the signature does not verify. + TokenScopeError: If the principal or capability do not match. + """ + ... + + +# ── Implementation ──────────────────────────────────────────────────────────── + + +class HMACTokenProvider: + """Issues and verifies HMAC-SHA256 capability tokens. + + The signing secret is read from the ``AGENT_KERNEL_SECRET`` environment + variable. If the variable is absent a random development secret is + generated and a warning is logged. 
+ """ + + def __init__(self, secret: str | None = None) -> None: + self._secret = secret # None → use env / dev fallback at call time + + def _secret_bytes(self) -> bytes: + return (self._secret or _get_secret()).encode() + + def _sign(self, payload: str) -> str: + return hmac.new(self._secret_bytes(), payload.encode(), hashlib.sha256).hexdigest() + + def issue( + self, + capability_id: str, + principal_id: str, + *, + constraints: dict[str, Any] | None = None, + ttl_seconds: int = 3600, + audit_id: str = "", + ) -> CapabilityToken: + """Issue a new signed token. + + Args: + capability_id: The capability this token authorises. + principal_id: The principal this token is issued to. + constraints: Optional execution constraints. + ttl_seconds: How long the token is valid (default 1 hour). + audit_id: Audit trail ID to embed in the token. + + Returns: + A freshly signed :class:`CapabilityToken`. + """ + now = datetime.datetime.now(tz=datetime.timezone.utc) + token = CapabilityToken( + token_id=str(uuid.uuid4()), + capability_id=capability_id, + principal_id=principal_id, + issued_at=now, + expires_at=now + datetime.timedelta(seconds=ttl_seconds), + constraints=constraints or {}, + audit_id=audit_id, + ) + token.signature = self._sign(token._signable_payload()) + return token + + def verify( + self, + token: CapabilityToken, + *, + expected_principal_id: str, + expected_capability_id: str, + ) -> None: + """Verify a token's signature, expiry, and scope bindings. + + Args: + token: The token to verify. + expected_principal_id: The principal that should own this token. + expected_capability_id: The capability this token should authorize. + + Raises: + TokenExpired: If ``token.expires_at`` is in the past. + TokenInvalid: If the HMAC signature does not verify. + TokenScopeError: If principal or capability do not match. + """ + # 1. 
Expiry + now = datetime.datetime.now(tz=datetime.timezone.utc) + if token.expires_at <= now: + raise TokenExpired( + f"Token '{token.token_id}' expired at {token.expires_at.isoformat()}." + ) + + # 2. Signature + expected_sig = self._sign(token._signable_payload()) + if not hmac.compare_digest(expected_sig, token.signature): + raise TokenInvalid( + f"Token '{token.token_id}' has an invalid signature. " + "The token may have been tampered with." + ) + + # 3. Principal binding (confused-deputy prevention) + if token.principal_id != expected_principal_id: + raise TokenScopeError( + f"Token '{token.token_id}' was issued for principal " + f"'{token.principal_id}', not '{expected_principal_id}'." + ) + + # 4. Capability binding + if token.capability_id != expected_capability_id: + raise TokenScopeError( + f"Token '{token.token_id}' was issued for capability " + f"'{token.capability_id}', not '{expected_capability_id}'." + ) diff --git a/src/agent_kernel/trace.py b/src/agent_kernel/trace.py new file mode 100644 index 0000000..ccbd2bd --- /dev/null +++ b/src/agent_kernel/trace.py @@ -0,0 +1,46 @@ +"""TraceStore: in-memory audit trail for kernel invocations.""" + +from __future__ import annotations + +from .errors import AgentKernelError +from .models import ActionTrace + + +class TraceStore: + """Stores :class:`ActionTrace` records indexed by ``action_id``. + + All invocations recorded by the :class:`~agent_kernel.kernel.Kernel` are + retrievable here for audit and explainability purposes. + """ + + def __init__(self) -> None: + self._traces: dict[str, ActionTrace] = {} + + def record(self, trace: ActionTrace) -> None: + """Store an action trace. + + Args: + trace: The :class:`ActionTrace` to record. + """ + self._traces[trace.action_id] = trace + + def get(self, action_id: str) -> ActionTrace: + """Retrieve an action trace by its ID. + + Args: + action_id: The unique action identifier. + + Returns: + The :class:`ActionTrace` for that action. 
+ + Raises: + AgentKernelError: If no trace with that ID exists. + """ + try: + return self._traces[action_id] + except KeyError: + raise AgentKernelError(f"No action trace found for action_id='{action_id}'.") from None + + def list_all(self) -> list[ActionTrace]: + """Return all recorded traces in insertion order.""" + return list(self._traces.values()) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0f74758 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,172 @@ +"""Shared test fixtures for agent-kernel tests.""" + +from __future__ import annotations + +import pytest + +from agent_kernel import ( + Capability, + CapabilityRegistry, + HMACTokenProvider, + InMemoryDriver, + Kernel, + Principal, + SafetyClass, + SensitivityTag, + StaticRouter, + make_billing_driver, +) +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.models import ImplementationRef + +# ── Capabilities ─────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def capabilities() -> list[Capability]: + return [ + Capability( + capability_id="billing.list_invoices", + name="List Invoices", + description="List all invoices for a customer", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + allowed_fields=["id", "amount", "currency", "status", "date"], + tags=["billing", "invoices", "list"], + impl=ImplementationRef(driver_id="billing", operation="list_invoices"), + ), + Capability( + capability_id="billing.get_invoice", + name="Get Invoice", + description="Get a single invoice by ID", + safety_class=SafetyClass.READ, + sensitivity=SensitivityTag.PII, + allowed_fields=["id", "amount", "currency", "status", "date", "line_items"], + tags=["billing", "invoice", "get", "detail"], + impl=ImplementationRef(driver_id="billing", operation="get_invoice"), + ), + Capability( + capability_id="billing.summarize_spend", + name="Summarize Spend", + description="Summarize total spend by currency and status", 
+ safety_class=SafetyClass.READ, + tags=["billing", "summary", "spend", "analytics"], + impl=ImplementationRef(driver_id="billing", operation="summarize_spend"), + ), + Capability( + capability_id="billing.update_invoice", + name="Update Invoice", + description="Update an existing invoice", + safety_class=SafetyClass.WRITE, + tags=["billing", "invoice", "update", "write"], + impl=ImplementationRef(driver_id="billing", operation="update_invoice"), + ), + Capability( + capability_id="billing.delete_invoice", + name="Delete Invoice", + description="Permanently delete an invoice", + safety_class=SafetyClass.DESTRUCTIVE, + tags=["billing", "invoice", "delete", "destructive"], + impl=ImplementationRef(driver_id="billing", operation="delete_invoice"), + ), + ] + + +@pytest.fixture() +def registry(capabilities: list[Capability]) -> CapabilityRegistry: + reg = CapabilityRegistry() + reg.register_many(capabilities) + return reg + + +# ── Principals ───────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def reader_principal() -> Principal: + return Principal( + principal_id="user-reader-001", + roles=["reader"], + attributes={"tenant": "acme"}, + ) + + +@pytest.fixture() +def writer_principal() -> Principal: + return Principal( + principal_id="user-writer-001", + roles=["reader", "writer"], + attributes={"tenant": "acme"}, + ) + + +@pytest.fixture() +def admin_principal() -> Principal: + return Principal( + principal_id="user-admin-001", + roles=["reader", "writer", "admin"], + attributes={"tenant": "acme"}, + ) + + +@pytest.fixture() +def service_principal() -> Principal: + return Principal( + principal_id="svc-analytics-001", + roles=["reader", "service"], + attributes={"tenant": "acme"}, + ) + + +# ── Drivers ──────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def billing_driver() -> InMemoryDriver: + return make_billing_driver() + + +@pytest.fixture() +def memory_driver() -> InMemoryDriver: + 
driver = InMemoryDriver(driver_id="memory") + + def echo(ctx: ExecutionContext) -> dict[str, object]: + return {"echo": ctx.args, "capability_id": ctx.capability_id} + + driver.register_handler("billing.list_invoices", echo) + driver.register_handler("billing.get_invoice", echo) + driver.register_handler("billing.summarize_spend", echo) + driver.register_handler("billing.update_invoice", echo) + driver.register_handler("billing.delete_invoice", echo) + return driver + + +# ── Token provider ───────────────────────────────────────────────────────────── + + +@pytest.fixture() +def token_provider() -> HMACTokenProvider: + return HMACTokenProvider(secret="test-secret-do-not-use-in-prod") + + +# ── Kernel ───────────────────────────────────────────────────────────────────── + + +@pytest.fixture() +def kernel(registry: CapabilityRegistry, memory_driver: InMemoryDriver) -> Kernel: + router = StaticRouter( + routes={ + "billing.list_invoices": ["memory"], + "billing.get_invoice": ["memory"], + "billing.summarize_spend": ["memory"], + "billing.update_invoice": ["memory"], + "billing.delete_invoice": ["memory"], + } + ) + k = Kernel( + registry=registry, + token_provider=HMACTokenProvider(secret="test-secret-do-not-use-in-prod"), + router=router, + ) + k.register_driver(memory_driver) + return k diff --git a/tests/test_drivers.py b/tests/test_drivers.py new file mode 100644 index 0000000..adfad78 --- /dev/null +++ b/tests/test_drivers.py @@ -0,0 +1,185 @@ +"""Tests for InMemoryDriver and HTTPDriver.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agent_kernel import DriverError, InMemoryDriver +from agent_kernel.drivers.base import ExecutionContext +from agent_kernel.drivers.http import HTTPDriver, HTTPEndpoint + +# ── InMemoryDriver ───────────────────────────────────────────────────────────── + + +def test_inmemory_register_and_execute() -> None: + driver = InMemoryDriver() + 
+ def handler(ctx: ExecutionContext) -> dict[str, Any]: + return {"result": "ok", "args": ctx.args} + + driver.register_handler("my_op", handler) + assert driver.driver_id == "memory" + + +@pytest.mark.asyncio +async def test_inmemory_execute_success() -> None: + driver = InMemoryDriver() + driver.register_handler("op1", lambda ctx: {"x": 1}) + ctx = ExecutionContext(capability_id="cap.x", principal_id="u1", args={"operation": "op1"}) + result = await driver.execute(ctx) + assert result.data == {"x": 1} + assert result.capability_id == "cap.x" + + +@pytest.mark.asyncio +async def test_inmemory_execute_fallback_to_capability_id() -> None: + driver = InMemoryDriver() + driver.register_handler("cap.x", lambda ctx: "direct") + ctx = ExecutionContext(capability_id="cap.x", principal_id="u1") + result = await driver.execute(ctx) + assert result.data == "direct" + + +@pytest.mark.asyncio +async def test_inmemory_execute_unknown_operation_raises() -> None: + driver = InMemoryDriver() + ctx = ExecutionContext(capability_id="cap.x", principal_id="u1", args={"operation": "noop"}) + with pytest.raises(DriverError, match="no handler"): + await driver.execute(ctx) + + +@pytest.mark.asyncio +async def test_inmemory_handler_exception_raises_driver_error() -> None: + driver = InMemoryDriver() + + def bad_handler(ctx: ExecutionContext) -> None: + raise RuntimeError("boom") + + driver.register_handler("bad_op", bad_handler) + ctx = ExecutionContext(capability_id="cap.x", principal_id="u1", args={"operation": "bad_op"}) + with pytest.raises(DriverError, match="boom"): + await driver.execute(ctx) + + +@pytest.mark.asyncio +async def test_billing_driver_list(billing_driver: InMemoryDriver) -> None: + ctx = ExecutionContext( + capability_id="billing.list_invoices", + principal_id="u1", + args={"operation": "list_invoices"}, + ) + result = await billing_driver.execute(ctx) + assert isinstance(result.data, list) + assert len(result.data) == 200 + + +@pytest.mark.asyncio +async def 
test_billing_driver_list_filtered(billing_driver: InMemoryDriver) -> None: + ctx = ExecutionContext( + capability_id="billing.list_invoices", + principal_id="u1", + args={"operation": "list_invoices", "status": "paid"}, + ) + result = await billing_driver.execute(ctx) + assert all(r["status"] == "paid" for r in result.data) + + +@pytest.mark.asyncio +async def test_billing_driver_get(billing_driver: InMemoryDriver) -> None: + ctx = ExecutionContext( + capability_id="billing.get_invoice", + principal_id="u1", + args={"operation": "get_invoice", "id": "INV-0001"}, + ) + result = await billing_driver.execute(ctx) + assert result.data is not None + assert result.data["id"] == "INV-0001" + + +@pytest.mark.asyncio +async def test_billing_driver_summarize(billing_driver: InMemoryDriver) -> None: + ctx = ExecutionContext( + capability_id="billing.summarize_spend", + principal_id="u1", + args={"operation": "summarize_spend"}, + ) + result = await billing_driver.execute(ctx) + assert "totals" in result.data + assert "invoice_count" in result.data + assert result.data["invoice_count"] == 200 + + +# ── HTTPDriver ───────────────────────────────────────────────────────────────── + + +def test_httpdriver_register_endpoint() -> None: + driver = HTTPDriver(driver_id="myhttp") + endpoint = HTTPEndpoint(url="http://example.com/api", method="GET") + driver.register_endpoint("op1", endpoint) + assert driver.driver_id == "myhttp" + + +@pytest.mark.asyncio +async def test_httpdriver_execute_get(monkeypatch: pytest.MonkeyPatch) -> None: + driver = HTTPDriver() + endpoint = HTTPEndpoint(url="http://localhost:9999/test", method="GET") + driver.register_endpoint("get_data", endpoint) + + mock_response = MagicMock() + mock_response.json.return_value = [{"id": 1}] + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + 
mock_client.get = AsyncMock(return_value=mock_response) + + with patch("agent_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + ctx = ExecutionContext( + capability_id="cap.x", + principal_id="u1", + args={"operation": "get_data"}, + ) + result = await driver.execute(ctx) + assert result.data == [{"id": 1}] + + +@pytest.mark.asyncio +async def test_httpdriver_unknown_operation_raises() -> None: + driver = HTTPDriver() + ctx = ExecutionContext(capability_id="cap.x", principal_id="u1", args={"operation": "noop"}) + with pytest.raises(DriverError, match="no endpoint"): + await driver.execute(ctx) + + +@pytest.mark.asyncio +async def test_httpdriver_http_error_raises(monkeypatch: pytest.MonkeyPatch) -> None: + import httpx + + driver = HTTPDriver() + endpoint = HTTPEndpoint(url="http://localhost:9999/fail", method="GET") + driver.register_endpoint("fail_op", endpoint) + + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + error = httpx.HTTPStatusError("Server Error", request=MagicMock(), response=mock_response) + mock_client.get = AsyncMock(side_effect=error) + + with patch("agent_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + ctx = ExecutionContext( + capability_id="cap.x", + principal_id="u1", + args={"operation": "fail_op"}, + ) + with pytest.raises(DriverError, match="HTTP 500"): + await driver.execute(ctx) diff --git a/tests/test_firewall.py b/tests/test_firewall.py new file mode 100644 index 0000000..1a17975 --- /dev/null +++ b/tests/test_firewall.py @@ -0,0 +1,157 @@ +"""Tests for the context Firewall.""" + +from __future__ import annotations + +import datetime + +from agent_kernel import Firewall +from agent_kernel.firewall.budgets import Budgets +from agent_kernel.models import Handle, RawResult + + +def 
_handle() -> Handle: + now = datetime.datetime.now(tz=datetime.timezone.utc) + return Handle( + handle_id="h1", + capability_id="cap.x", + created_at=now, + expires_at=now + datetime.timedelta(hours=1), + total_rows=200, + ) + + +def _transform( + data: object, + response_mode: str = "summary", + *, + principal_roles: list[str] | None = None, + constraints: dict[str, object] | None = None, + budgets: Budgets | None = None, +) -> object: + fw = Firewall(budgets=budgets) + raw = RawResult(capability_id="cap.x", data=data) + return fw.transform( + raw, + action_id="act-1", + principal_id="u1", + principal_roles=principal_roles or [], + response_mode=response_mode, # type: ignore[arg-type] + constraints=constraints, + handle=_handle(), + ) + + +# ── Summary mode ─────────────────────────────────────────────────────────────── + + +def test_summary_list_of_dicts() -> None: + rows = [{"id": i, "amount": float(i * 10)} for i in range(100)] + frame = _transform(rows, "summary") + assert frame.response_mode == "summary" # type: ignore[union-attr] + assert len(frame.facts) > 0 # type: ignore[union-attr] + assert "Total rows: 100" in frame.facts # type: ignore[union-attr] + + +def test_summary_dict() -> None: + data = {"totals": {"USD": 1000.0}, "invoice_count": 200} + frame = _transform(data, "summary") + assert any("invoice_count" in f for f in frame.facts) # type: ignore[union-attr] + + +def test_summary_string() -> None: + frame = _transform("hello world", "summary") + assert frame.response_mode == "summary" # type: ignore[union-attr] + + +# ── Table mode ───────────────────────────────────────────────────────────────── + + +def test_table_row_cap() -> None: + rows = [{"id": i} for i in range(200)] + budgets = Budgets(max_rows=10) + frame = _transform(rows, "table", budgets=budgets) + assert len(frame.table_preview) <= 10 # type: ignore[union-attr] + + +def test_table_field_cap() -> None: + rows = [{"f" + str(j): j for j in range(50)}] + budgets = Budgets(max_fields=5) + 
frame = _transform(rows, "table", budgets=budgets) + assert all(len(r) <= 5 for r in frame.table_preview) # type: ignore[union-attr] + + +def test_table_max_rows_from_constraints() -> None: + rows = [{"id": i} for i in range(100)] + frame = _transform(rows, "table", constraints={"max_rows": 3}) + assert len(frame.table_preview) <= 3 # type: ignore[union-attr] + + +# ── Handle-only mode ─────────────────────────────────────────────────────────── + + +def test_handle_only() -> None: + frame = _transform([1, 2, 3], "handle_only") + assert frame.response_mode == "handle_only" # type: ignore[union-attr] + assert frame.handle is not None # type: ignore[union-attr] + assert frame.table_preview == [] # type: ignore[union-attr] + assert frame.facts == [] # type: ignore[union-attr] + + +# ── Raw mode ─────────────────────────────────────────────────────────────────── + + +def test_raw_mode_admin() -> None: + data = {"secret": "data"} + frame = _transform(data, "raw", principal_roles=["admin"]) + assert frame.response_mode == "raw" # type: ignore[union-attr] + assert frame.raw_data is not None # type: ignore[union-attr] + + +def test_raw_mode_non_admin_falls_back_to_summary() -> None: + data = {"secret": "data"} + frame = _transform(data, "raw", principal_roles=["reader"]) + assert frame.response_mode == "summary" # type: ignore[union-attr] + assert any("raw mode requires admin" in w for w in frame.warnings) # type: ignore[union-attr] + + +# ── Char budget ──────────────────────────────────────────────────────────────── + + +def test_char_budget_limits_facts() -> None: + big_string = "x" * 3000 + rows = [{"description": big_string} for _ in range(10)] + budgets = Budgets(max_chars=100) + frame = _transform(rows, "summary", budgets=budgets) + total = sum(len(f) for f in frame.facts) # type: ignore[union-attr] + assert total <= 200 # allow some slack for the budget check + + +# ── PII redaction ────────────────────────────────────────────────────────────── + + +def 
test_pii_allowed_fields_redaction() -> None: + rows = [{"id": 1, "email": "user@example.com", "amount": 100.0}] + frame = _transform( + rows, + "table", + constraints={"allowed_fields": ["id", "amount"]}, + ) + row = frame.table_preview[0] # type: ignore[union-attr] + assert "email" not in row + assert "id" in row + + +def test_redaction_warnings() -> None: + rows = [{"id": 1, "email": "test@example.com"}] + frame = _transform(rows, "table", constraints={"allowed_fields": ["id"]}) + assert any("email" in w for w in frame.warnings) # type: ignore[union-attr] + + +# ── max_depth ────────────────────────────────────────────────────────────────── + + +def test_max_depth_limiting() -> None: + deep = {"a": {"b": {"c": {"d": {"e": "deep"}}}}} + budgets = Budgets(max_depth=2) + frame = _transform(deep, "summary", budgets=budgets) + assert frame.response_mode == "summary" # type: ignore[union-attr] diff --git a/tests/test_handles.py b/tests/test_handles.py new file mode 100644 index 0000000..9e44348 --- /dev/null +++ b/tests/test_handles.py @@ -0,0 +1,108 @@ +"""Tests for HandleStore.""" + +from __future__ import annotations + +import datetime + +import pytest + +from agent_kernel import HandleExpired, HandleNotFound, HandleStore +from agent_kernel.models import Handle + + +@pytest.fixture() +def store() -> HandleStore: + return HandleStore(default_ttl_seconds=3600) + + +def test_store_and_retrieve(store: HandleStore) -> None: + data = [{"id": i} for i in range(10)] + handle = store.store("cap.x", data) + assert handle.total_rows == 10 + retrieved = store.get(handle.handle_id) + assert retrieved == data + + +def test_get_meta(store: HandleStore) -> None: + handle = store.store("cap.x", [1, 2, 3]) + meta = store.get_meta(handle.handle_id) + assert meta.handle_id == handle.handle_id + assert meta.capability_id == "cap.x" + + +def test_get_unknown_raises(store: HandleStore) -> None: + with pytest.raises(HandleNotFound): + store.get("nonexistent-handle-id") + + +def 
test_get_expired_raises(store: HandleStore) -> None: + handle = store.store("cap.x", [1, 2, 3], ttl_seconds=-1) + with pytest.raises(HandleExpired): + store.get(handle.handle_id) + + +def test_evict_expired(store: HandleStore) -> None: + store.store("cap.x", [1], ttl_seconds=-1) + store.store("cap.x", [2], ttl_seconds=-1) + store.store("cap.x", [3], ttl_seconds=3600) + evicted = store.evict_expired() + assert evicted == 2 + + +# ── Expand ───────────────────────────────────────────────────────────────────── + + +def _make_handle(store: HandleStore) -> Handle: + data = [ + {"id": i, "status": "paid" if i % 2 == 0 else "unpaid", "amount": float(i * 10)} + for i in range(20) + ] + return store.store("cap.x", data) + + +def test_expand_pagination(store: HandleStore) -> None: + handle = _make_handle(store) + frame = store.expand(handle, query={"offset": 5, "limit": 3}) + assert len(frame.table_preview) == 3 + assert frame.table_preview[0]["id"] == 5 + + +def test_expand_field_selection(store: HandleStore) -> None: + handle = _make_handle(store) + frame = store.expand(handle, query={"fields": ["id", "status"]}) + assert all(set(r.keys()) == {"id", "status"} for r in frame.table_preview) + + +def test_expand_filter(store: HandleStore) -> None: + handle = _make_handle(store) + frame = store.expand(handle, query={"filter": {"status": "paid"}}) + assert all(r["status"] == "paid" for r in frame.table_preview) + + +def test_expand_combined(store: HandleStore) -> None: + handle = _make_handle(store) + frame = store.expand( + handle, + query={"filter": {"status": "unpaid"}, "offset": 0, "limit": 2, "fields": ["id"]}, + ) + assert len(frame.table_preview) <= 2 + assert all("id" in r for r in frame.table_preview) + assert all("status" not in r for r in frame.table_preview) + + +def test_expand_expired_raises(store: HandleStore) -> None: + handle = store.store("cap.x", [1, 2, 3], ttl_seconds=-1) + with pytest.raises(HandleExpired): + store.expand(handle, query={}) + + +def 
test_expand_handle_not_found(store: HandleStore) -> None: + now = datetime.datetime.now(tz=datetime.timezone.utc) + fake_handle = Handle( + handle_id="fake-id", + capability_id="cap.x", + created_at=now, + expires_at=now + datetime.timedelta(hours=1), + ) + with pytest.raises(HandleNotFound): + store.expand(fake_handle, query={}) diff --git a/tests/test_kernel.py b/tests/test_kernel.py new file mode 100644 index 0000000..46a8490 --- /dev/null +++ b/tests/test_kernel.py @@ -0,0 +1,217 @@ +"""Integration tests for the Kernel (full flow).""" + +from __future__ import annotations + +import pytest + +from agent_kernel import ( + Capability, + CapabilityRegistry, + DriverError, + HMACTokenProvider, + InMemoryDriver, + Kernel, + PolicyDenied, + Principal, + SafetyClass, + StaticRouter, + TokenExpired, +) +from agent_kernel.models import CapabilityRequest + +# ── Full flow: request → grant → invoke → expand → explain ───────────────────── + + +@pytest.mark.asyncio +async def test_full_flow(kernel: Kernel, reader_principal: Principal) -> None: + requests = kernel.request_capabilities("list invoices") + assert len(requests) > 0 + + req = CapabilityRequest( + capability_id="billing.list_invoices", + goal="list all invoices", + ) + token = kernel.get_token(req, reader_principal, justification="") + assert token.capability_id == "billing.list_invoices" + + frame = await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + assert frame.response_mode == "summary" + assert frame.action_id != "" + + # explain + trace = kernel.explain(frame.action_id) + assert trace.capability_id == "billing.list_invoices" + assert trace.principal_id == reader_principal.principal_id + + # expand + assert frame.handle is not None + expanded = kernel.expand(frame.handle, query={"offset": 0, "limit": 2}) + assert len(expanded.table_preview) <= 2 + + +@pytest.mark.asyncio +async def test_invoke_table_mode(kernel: Kernel, reader_principal: Principal) 
-> None: + req = CapabilityRequest(capability_id="billing.list_invoices", goal="table") + token = kernel.get_token(req, reader_principal, justification="") + frame = await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + response_mode="table", + ) + assert frame.response_mode == "table" + + +@pytest.mark.asyncio +async def test_invoke_handle_only_mode(kernel: Kernel, reader_principal: Principal) -> None: + req = CapabilityRequest(capability_id="billing.list_invoices", goal="handle") + token = kernel.get_token(req, reader_principal, justification="") + frame = await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + response_mode="handle_only", + ) + assert frame.response_mode == "handle_only" + assert frame.handle is not None + + +# ── Denial flow ──────────────────────────────────────────────────────────────── + + +def test_grant_denied_write_no_role(kernel: Kernel, reader_principal: Principal) -> None: + req = CapabilityRequest( + capability_id="billing.update_invoice", + goal="update invoice", + ) + with pytest.raises(PolicyDenied): + kernel.get_token(req, reader_principal, justification="long enough justification here") + + +def test_grant_denied_destructive_no_admin(kernel: Kernel, writer_principal: Principal) -> None: + req = CapabilityRequest( + capability_id="billing.delete_invoice", + goal="delete invoice", + ) + with pytest.raises(PolicyDenied): + kernel.get_token(req, writer_principal, justification="long enough justification here") + + +def test_grant_allowed_write_writer_role(kernel: Kernel, writer_principal: Principal) -> None: + req = CapabilityRequest( + capability_id="billing.update_invoice", + goal="update invoice", + ) + token = kernel.get_token( + req, writer_principal, justification="this is a long enough justification" + ) + assert token.capability_id == "billing.update_invoice" + + +# ── Expired token flow 
───────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_invoke_expired_token(kernel: Kernel, reader_principal: Principal) -> None: + token_provider = HMACTokenProvider(secret="test-secret-do-not-use-in-prod") + token = token_provider.issue( + "billing.list_invoices", + reader_principal.principal_id, + ttl_seconds=-1, + ) + with pytest.raises(TokenExpired): + await kernel.invoke( + token, + principal=reader_principal, + args={"operation": "billing.list_invoices"}, + ) + + +# ── Fallback driver flow ─────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_fallback_driver_flow() -> None: + """If the first driver fails, the kernel tries the next one.""" + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="test.cap", + name="Test", + description="Test capability", + safety_class=SafetyClass.READ, + ) + ) + + primary = InMemoryDriver(driver_id="primary") + # primary raises DriverError + primary.register_handler( + "test.cap", lambda ctx: (_ for _ in ()).throw(DriverError("primary fail")) + ) + + fallback = InMemoryDriver(driver_id="fallback") + fallback.register_handler("test.cap", lambda ctx: {"from": "fallback"}) + + router = StaticRouter(routes={"test.cap": ["primary", "fallback"]}) + token_provider = HMACTokenProvider(secret="test-secret") + k = Kernel(registry=registry, router=router, token_provider=token_provider) + k.register_driver(primary) + k.register_driver(fallback) + + principal = Principal(principal_id="u1") + token = token_provider.issue("test.cap", "u1") + frame = await k.invoke(token, principal=principal, args={}) + assert frame.response_mode == "summary" + trace = k.explain(frame.action_id) + assert trace.driver_id == "fallback" + + +@pytest.mark.asyncio +async def test_all_drivers_fail_raises_driver_error() -> None: + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="test.fail", + name="Fail", + 
description="Always fails", + safety_class=SafetyClass.READ, + ) + ) + bad_driver = InMemoryDriver(driver_id="bad") + bad_driver.register_handler( + "test.fail", lambda ctx: (_ for _ in ()).throw(DriverError("always fail")) + ) + + router = StaticRouter(routes={"test.fail": ["bad"]}) + token_provider = HMACTokenProvider(secret="test-secret") + k = Kernel(registry=registry, router=router, token_provider=token_provider) + k.register_driver(bad_driver) + + principal = Principal(principal_id="u1") + token = token_provider.issue("test.fail", "u1") + with pytest.raises(DriverError): + await k.invoke(token, principal=principal, args={}) + + +# ── Confused-deputy prevention ───────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_confused_deputy_prevention(kernel: Kernel, reader_principal: Principal) -> None: + """A token issued for one principal cannot be used by another.""" + req = CapabilityRequest(capability_id="billing.list_invoices", goal="test") + token = kernel.get_token(req, reader_principal, justification="") + + other_principal = Principal(principal_id="attacker-999", roles=["reader"]) + from agent_kernel import TokenScopeError + + with pytest.raises(TokenScopeError): + await kernel.invoke( + token, + principal=other_principal, + args={"operation": "billing.list_invoices"}, + ) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..cf5b024 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,154 @@ +"""Tests for core dataclasses and models.""" + +from __future__ import annotations + +import datetime + +from agent_kernel.enums import SafetyClass, SensitivityTag +from agent_kernel.models import ( + ActionTrace, + Budgets, + Capability, + CapabilityRequest, + Frame, + Handle, + ImplementationRef, + PolicyDecision, + Principal, + RawResult, + RoutePlan, +) +from agent_kernel.tokens import CapabilityToken + + +def test_capability_construction() -> None: + cap = Capability( + 
capability_id="test.cap", + name="Test Cap", + description="A test capability", + safety_class=SafetyClass.READ, + ) + assert cap.capability_id == "test.cap" + assert cap.safety_class == SafetyClass.READ + assert cap.sensitivity == SensitivityTag.NONE + assert cap.allowed_fields == [] + assert cap.tags == [] + assert cap.impl is None + + +def test_capability_with_all_fields() -> None: + impl = ImplementationRef(driver_id="memory", operation="op1") + cap = Capability( + capability_id="test.full", + name="Full Cap", + description="Full capability", + safety_class=SafetyClass.WRITE, + sensitivity=SensitivityTag.PII, + allowed_fields=["id", "name"], + tags=["tag1", "tag2"], + impl=impl, + ) + assert cap.impl is not None + assert cap.impl.driver_id == "memory" + assert cap.impl.operation == "op1" + assert cap.tags == ["tag1", "tag2"] + + +def test_principal_defaults() -> None: + p = Principal(principal_id="user-001") + assert p.roles == [] + assert p.attributes == {} + + +def test_capability_request() -> None: + req = CapabilityRequest( + capability_id="test.cap", + goal="I need to list things", + constraints={"max_rows": 10}, + ) + assert req.capability_id == "test.cap" + assert req.constraints["max_rows"] == 10 + + +def test_policy_decision() -> None: + dec = PolicyDecision(allowed=True, reason="OK", constraints={"max_rows": 50}) + assert dec.allowed is True + assert dec.constraints["max_rows"] == 50 + + +def test_raw_result() -> None: + rr = RawResult(capability_id="cap.x", data=[1, 2, 3]) + assert rr.data == [1, 2, 3] + assert rr.metadata == {} + + +def test_frame_defaults() -> None: + frame = Frame(action_id="a1", capability_id="cap.x", response_mode="summary") + assert frame.facts == [] + assert frame.table_preview == [] + assert frame.handle is None + assert frame.warnings == [] + + +def test_handle_construction() -> None: + now = datetime.datetime.now(tz=datetime.timezone.utc) + h = Handle( + handle_id="h1", + capability_id="cap.x", + created_at=now, + 
expires_at=now + datetime.timedelta(hours=1), + total_rows=100, + ) + assert h.handle_id == "h1" + assert h.total_rows == 100 + + +def test_budgets_defaults() -> None: + b = Budgets() + assert b.max_rows == 50 + assert b.max_fields == 20 + assert b.max_chars == 4000 + assert b.max_depth == 3 + + +def test_action_trace() -> None: + now = datetime.datetime.now(tz=datetime.timezone.utc) + trace = ActionTrace( + action_id="act-1", + capability_id="cap.x", + principal_id="user-1", + token_id="tok-1", + invoked_at=now, + args={"a": 1}, + response_mode="summary", + driver_id="memory", + ) + assert trace.action_id == "act-1" + assert trace.error is None + assert trace.handle_id is None + + +def test_route_plan() -> None: + plan = RoutePlan(capability_id="cap.x", driver_ids=["memory", "http"]) + assert plan.driver_ids == ["memory", "http"] + + +def test_capability_token_from_to_dict() -> None: + now = datetime.datetime.now(tz=datetime.timezone.utc) + token = CapabilityToken( + token_id="tok-1", + capability_id="cap.x", + principal_id="user-1", + issued_at=now, + expires_at=now + datetime.timedelta(hours=1), + constraints={"max_rows": 10}, + audit_id="audit-1", + signature="sig", + ) + d = token.to_dict() + assert d["token_id"] == "tok-1" + assert d["signature"] == "sig" + + restored = CapabilityToken.from_dict(d) + assert restored.token_id == "tok-1" + assert restored.constraints == {"max_rows": 10} diff --git a/tests/test_policy.py b/tests/test_policy.py new file mode 100644 index 0000000..ba217cb --- /dev/null +++ b/tests/test_policy.py @@ -0,0 +1,194 @@ +"""Tests for DefaultPolicyEngine.""" + +from __future__ import annotations + +import pytest + +from agent_kernel import ( + Capability, + DefaultPolicyEngine, + PolicyDenied, + Principal, + SafetyClass, + SensitivityTag, +) +from agent_kernel.models import CapabilityRequest + + +def _req(cap_id: str, **constraints: object) -> CapabilityRequest: + return CapabilityRequest(capability_id=cap_id, goal="test", 
constraints=dict(constraints)) + + +def _cap( + cap_id: str, + safety: SafetyClass, + sensitivity: SensitivityTag = SensitivityTag.NONE, + allowed_fields: list[str] | None = None, +) -> Capability: + return Capability( + capability_id=cap_id, + name=cap_id, + description="test", + safety_class=safety, + sensitivity=sensitivity, + allowed_fields=allowed_fields or [], + ) + + +engine = DefaultPolicyEngine() + + +# ── READ ─────────────────────────────────────────────────────────────────────── + + +def test_read_allowed_no_roles() -> None: + p = Principal(principal_id="u1") + dec = engine.evaluate(_req("cap.r"), _cap("cap.r", SafetyClass.READ), p, justification="") + assert dec.allowed is True + + +def test_read_sets_max_rows_user() -> None: + p = Principal(principal_id="u1", roles=["reader"]) + dec = engine.evaluate(_req("cap.r"), _cap("cap.r", SafetyClass.READ), p, justification="") + assert dec.constraints["max_rows"] == 50 + + +def test_read_sets_max_rows_service() -> None: + p = Principal(principal_id="svc1", roles=["service"]) + dec = engine.evaluate(_req("cap.r"), _cap("cap.r", SafetyClass.READ), p, justification="") + assert dec.constraints["max_rows"] == 500 + + +def test_read_respects_tighter_constraint() -> None: + p = Principal(principal_id="u1") + dec = engine.evaluate( + _req("cap.r", max_rows=5), _cap("cap.r", SafetyClass.READ), p, justification="" + ) + assert dec.constraints["max_rows"] == 5 + + +def test_read_tighter_constraint_cannot_exceed_cap() -> None: + p = Principal(principal_id="u1") + dec = engine.evaluate( + _req("cap.r", max_rows=9999), _cap("cap.r", SafetyClass.READ), p, justification="" + ) + assert dec.constraints["max_rows"] == 50 + + +# ── WRITE ────────────────────────────────────────────────────────────────────── + + +def test_write_denied_no_role() -> None: + p = Principal(principal_id="u1", roles=["reader"]) + with pytest.raises(PolicyDenied, match="writer.*admin"): + engine.evaluate( + _req("cap.w"), + _cap("cap.w", 
SafetyClass.WRITE), + p, + justification="long enough justification here", + ) + + +def test_write_denied_short_justification() -> None: + p = Principal(principal_id="u1", roles=["writer"]) + with pytest.raises(PolicyDenied, match="justification"): + engine.evaluate( + _req("cap.w"), _cap("cap.w", SafetyClass.WRITE), p, justification="too short" + ) + + +def test_write_allowed_writer_role() -> None: + p = Principal(principal_id="u1", roles=["writer"]) + dec = engine.evaluate( + _req("cap.w"), + _cap("cap.w", SafetyClass.WRITE), + p, + justification="this is a long enough justification string", + ) + assert dec.allowed is True + + +def test_write_allowed_admin_role() -> None: + p = Principal(principal_id="u1", roles=["admin"]) + dec = engine.evaluate( + _req("cap.w"), + _cap("cap.w", SafetyClass.WRITE), + p, + justification="this is a long enough justification string", + ) + assert dec.allowed is True + + +# ── DESTRUCTIVE ──────────────────────────────────────────────────────────────── + + +def test_destructive_denied_no_admin() -> None: + p = Principal(principal_id="u1", roles=["writer"]) + with pytest.raises(PolicyDenied, match="admin"): + engine.evaluate( + _req("cap.d"), + _cap("cap.d", SafetyClass.DESTRUCTIVE), + p, + justification="long enough justification", + ) + + +def test_destructive_allowed_admin() -> None: + p = Principal(principal_id="u1", roles=["admin"]) + dec = engine.evaluate( + _req("cap.d"), + _cap("cap.d", SafetyClass.DESTRUCTIVE), + p, + justification="long enough justification", + ) + assert dec.allowed is True + + +# ── PII / PCI ────────────────────────────────────────────────────────────────── + + +def test_pii_requires_tenant() -> None: + p = Principal(principal_id="u1", roles=["reader"]) + cap = _cap("cap.pii", SafetyClass.READ, SensitivityTag.PII) + with pytest.raises(PolicyDenied, match="tenant"): + engine.evaluate(_req("cap.pii"), cap, p, justification="") + + +def test_pii_allowed_with_tenant() -> None: + p = 
Principal(principal_id="u1", roles=["reader"], attributes={"tenant": "acme"}) + cap = _cap("cap.pii", SafetyClass.READ, SensitivityTag.PII) + dec = engine.evaluate(_req("cap.pii"), cap, p, justification="") + assert dec.allowed is True + + +def test_pii_enforces_allowed_fields() -> None: + p = Principal(principal_id="u1", roles=["reader"], attributes={"tenant": "acme"}) + cap = _cap("cap.pii", SafetyClass.READ, SensitivityTag.PII, allowed_fields=["id", "name"]) + dec = engine.evaluate(_req("cap.pii"), cap, p, justification="") + assert dec.constraints.get("allowed_fields") == ["id", "name"] + + +def test_pii_reader_skips_allowed_fields() -> None: + p = Principal(principal_id="u1", roles=["reader", "pii_reader"], attributes={"tenant": "acme"}) + cap = _cap("cap.pii", SafetyClass.READ, SensitivityTag.PII, allowed_fields=["id", "name"]) + dec = engine.evaluate(_req("cap.pii"), cap, p, justification="") + assert "allowed_fields" not in dec.constraints + + +def test_pci_requires_tenant() -> None: + p = Principal(principal_id="u1", roles=["reader"]) + cap = _cap("cap.pci", SafetyClass.READ, SensitivityTag.PCI) + with pytest.raises(PolicyDenied, match="tenant"): + engine.evaluate(_req("cap.pci"), cap, p, justification="") + + +# ── Confused-deputy binding (via token) ──────────────────────────────────────── + + +def test_max_rows_enforcement() -> None: + """max_rows in constraints is capped by the policy ceiling.""" + p = Principal(principal_id="u1") + dec = engine.evaluate( + _req("cap.r", max_rows=200), _cap("cap.r", SafetyClass.READ), p, justification="" + ) + assert dec.constraints["max_rows"] == 50 diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..72fd5a3 --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,106 @@ +"""Tests for CapabilityRegistry.""" + +from __future__ import annotations + +import pytest + +from agent_kernel import Capability, CapabilityNotFound, CapabilityRegistry, SafetyClass + + +def 
_make_cap(cap_id: str, **kwargs: object) -> Capability: + defaults: dict[str, object] = { + "name": cap_id.replace(".", " ").title(), + "description": f"Description for {cap_id}", + "safety_class": SafetyClass.READ, + } + defaults.update(kwargs) + return Capability(capability_id=cap_id, **defaults) # type: ignore[arg-type] + + +def test_register_and_get() -> None: + reg = CapabilityRegistry() + cap = _make_cap("test.cap") + reg.register(cap) + assert reg.get("test.cap") is cap + + +def test_register_duplicate_raises() -> None: + reg = CapabilityRegistry() + reg.register(_make_cap("test.dup")) + with pytest.raises(ValueError, match="already registered"): + reg.register(_make_cap("test.dup")) + + +def test_get_unknown_raises() -> None: + reg = CapabilityRegistry() + with pytest.raises(CapabilityNotFound): + reg.get("does.not.exist") + + +def test_register_many() -> None: + reg = CapabilityRegistry() + caps = [_make_cap(f"cap.{i}") for i in range(5)] + reg.register_many(caps) + assert len(reg.list_all()) == 5 + + +def test_list_all_order() -> None: + reg = CapabilityRegistry() + for i in range(3): + reg.register(_make_cap(f"cap.{i}")) + ids = [c.capability_id for c in reg.list_all()] + assert ids == ["cap.0", "cap.1", "cap.2"] + + +def test_search_basic(registry: CapabilityRegistry) -> None: + results = registry.search("list invoices") + assert len(results) > 0 + ids = [r.capability_id for r in results] + assert "billing.list_invoices" in ids + + +def test_search_returns_capabilityrequest(registry: CapabilityRegistry) -> None: + from agent_kernel.models import CapabilityRequest + + results = registry.search("billing invoice") + assert all(isinstance(r, CapabilityRequest) for r in results) + + +def test_search_empty_goal(registry: CapabilityRegistry) -> None: + results = registry.search("") + assert results == [] + + +def test_search_no_matches(registry: CapabilityRegistry) -> None: + results = registry.search("zzz completely unrelated xyz") + assert results == [] + + 
+def test_search_max_results() -> None: + reg = CapabilityRegistry() + for i in range(20): + reg.register(_make_cap(f"search.cap{i}", description=f"billing invoice item {i}")) + results = reg.search("billing invoice", max_results=5) + assert len(results) <= 5 + + +def test_search_keyword_in_tags() -> None: + reg = CapabilityRegistry() + reg.register( + Capability( + capability_id="tag.test", + name="Tag Test", + description="Unrelated description", + safety_class=SafetyClass.READ, + tags=["uniquetag123"], + ) + ) + results = reg.search("uniquetag123") + assert len(results) == 1 + assert results[0].capability_id == "tag.test" + + +def test_search_goal_preserved(registry: CapabilityRegistry) -> None: + goal = "list all billing invoices please" + results = registry.search(goal) + assert all(r.goal == goal for r in results) diff --git a/tests/test_router.py b/tests/test_router.py new file mode 100644 index 0000000..1777127 --- /dev/null +++ b/tests/test_router.py @@ -0,0 +1,39 @@ +"""Tests for StaticRouter.""" + +from __future__ import annotations + +from agent_kernel import StaticRouter + + +def test_explicit_route() -> None: + router = StaticRouter(routes={"cap.x": ["http", "memory"]}) + plan = router.route("cap.x") + assert plan.driver_ids == ["http", "memory"] + assert plan.capability_id == "cap.x" + + +def test_fallback_route() -> None: + router = StaticRouter(routes={}, fallback=["memory"]) + plan = router.route("cap.unknown") + assert plan.driver_ids == ["memory"] + + +def test_default_fallback() -> None: + router = StaticRouter() + plan = router.route("anything") + assert "memory" in plan.driver_ids + + +def test_add_route() -> None: + router = StaticRouter() + router.add_route("cap.new", ["http"]) + plan = router.route("cap.new") + assert plan.driver_ids == ["http"] + + +def test_route_returns_copy() -> None: + """Mutating the returned driver_ids should not affect the router.""" + router = StaticRouter(routes={"cap.x": ["memory"]}) + plan = 
router.route("cap.x") + plan.driver_ids.append("corrupted") + assert router.route("cap.x").driver_ids == ["memory"] diff --git a/tests/test_tokens.py b/tests/test_tokens.py new file mode 100644 index 0000000..3b0fd97 --- /dev/null +++ b/tests/test_tokens.py @@ -0,0 +1,105 @@ +"""Tests for HMACTokenProvider.""" + +from __future__ import annotations + +import pytest + +from agent_kernel import ( + HMACTokenProvider, + TokenExpired, + TokenInvalid, + TokenScopeError, +) + + +@pytest.fixture() +def provider() -> HMACTokenProvider: + return HMACTokenProvider(secret="test-secret-12345") + + +def test_issue_returns_token(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1") + assert token.capability_id == "cap.x" + assert token.principal_id == "user-1" + assert token.signature != "" + assert token.token_id != "" + + +def test_verify_valid_token(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1") + # Should not raise + provider.verify(token, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_verify_expired_token(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1", ttl_seconds=-1) + with pytest.raises(TokenExpired): + provider.verify(token, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_verify_tampered_signature(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1") + # Flip the first character of the signature + flipped = ("a" if token.signature[0] != "a" else "b") + token.signature[1:] + from dataclasses import replace + + tampered = replace(token, signature=flipped) + with pytest.raises(TokenInvalid): + provider.verify(tampered, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_verify_wrong_principal(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1") + with pytest.raises(TokenScopeError, match="principal"): + provider.verify(token, 
expected_principal_id="user-2", expected_capability_id="cap.x") + + +def test_verify_wrong_capability(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1") + with pytest.raises(TokenScopeError, match="capability"): + provider.verify(token, expected_principal_id="user-1", expected_capability_id="cap.y") + + +def test_token_with_constraints(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1", constraints={"max_rows": 10}) + assert token.constraints["max_rows"] == 10 + # Verification should still pass + provider.verify(token, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_token_serialization_roundtrip(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1", constraints={"foo": "bar"}) + d = token.to_dict() + from agent_kernel.tokens import CapabilityToken + + restored = CapabilityToken.from_dict(d) + assert restored.token_id == token.token_id + assert restored.signature == token.signature + # Verification should still pass on the restored token + provider.verify(restored, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_tamper_constraints_invalidates_token(provider: HMACTokenProvider) -> None: + token = provider.issue("cap.x", "user-1", constraints={"max_rows": 10}) + d = token.to_dict() + d["constraints"]["max_rows"] = 9999 # tamper + from agent_kernel.tokens import CapabilityToken + + tampered = CapabilityToken.from_dict(d) + with pytest.raises(TokenInvalid): + provider.verify(tampered, expected_principal_id="user-1", expected_capability_id="cap.x") + + +def test_dev_secret_warning(caplog: pytest.LogCaptureFixture) -> None: + """A provider with no secret should generate a warning.""" + import logging + + # Reset dev secret to force warning + import agent_kernel.tokens as tok_mod + + tok_mod._DEV_SECRET = None + provider_no_secret = HMACTokenProvider(secret=None) + with caplog.at_level(logging.WARNING, 
logger="agent_kernel.tokens"): + token = provider_no_secret.issue("cap.x", "user-1") + assert "AGENT_KERNEL_SECRET" in caplog.text + assert token.signature != "" diff --git a/tests/test_trace.py b/tests/test_trace.py new file mode 100644 index 0000000..5804059 --- /dev/null +++ b/tests/test_trace.py @@ -0,0 +1,58 @@ +"""Tests for TraceStore.""" + +from __future__ import annotations + +import datetime + +import pytest + +from agent_kernel import TraceStore +from agent_kernel.errors import AgentKernelError +from agent_kernel.models import ActionTrace + + +def _trace(action_id: str = "act-1") -> ActionTrace: + return ActionTrace( + action_id=action_id, + capability_id="cap.x", + principal_id="u1", + token_id="tok-1", + invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), + args={"a": 1}, + response_mode="summary", + driver_id="memory", + ) + + +def test_record_and_get() -> None: + store = TraceStore() + t = _trace("act-1") + store.record(t) + result = store.get("act-1") + assert result is t + + +def test_get_unknown_raises() -> None: + store = TraceStore() + with pytest.raises(AgentKernelError, match="act-missing"): + store.get("act-missing") + + +def test_list_all() -> None: + store = TraceStore() + for i in range(3): + store.record(_trace(f"act-{i}")) + all_traces = store.list_all() + assert len(all_traces) == 3 + assert [t.action_id for t in all_traces] == ["act-0", "act-1", "act-2"] + + +def test_explain_returns_consistent_data() -> None: + store = TraceStore() + t = _trace("act-explain") + store.record(t) + result = store.get("act-explain") + assert result.capability_id == "cap.x" + assert result.principal_id == "u1" + assert result.driver_id == "memory" + assert result.args == {"a": 1} From b8cfb2223b16ec54e45c394f9956d580983aaa34 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:03:38 +0000 Subject: [PATCH 03/19] fix: add explicit permissions to CI workflow (CodeQL alert) 
Co-authored-by: dgenio <12731907+dgenio@users.noreply.github.com> --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fc901e6..27613b8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,8 @@ jobs: test: name: "Python ${{ matrix.python-version }}" runs-on: ubuntu-latest + permissions: + contents: read strategy: matrix: python-version: ["3.10", "3.11", "3.12"] From 68760bc254a5f048cf987a8d4340c2ec9b11f890 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:31:43 +0000 Subject: [PATCH 04/19] fix: replace bare ValueError with CapabilityAlreadyRegistered in registry --- src/agent_kernel/__init__.py | 2 ++ src/agent_kernel/errors.py | 4 ++++ src/agent_kernel/registry.py | 6 +++--- tests/test_registry.py | 10 ++++++++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/agent_kernel/__init__.py b/src/agent_kernel/__init__.py index acc22e6..69efbc1 100644 --- a/src/agent_kernel/__init__.py +++ b/src/agent_kernel/__init__.py @@ -41,6 +41,7 @@ from .enums import SafetyClass, SensitivityTag from .errors import ( AgentKernelError, + CapabilityAlreadyRegistered, CapabilityNotFound, DriverError, FirewallError, @@ -105,6 +106,7 @@ "SensitivityTag", # errors "AgentKernelError", + "CapabilityAlreadyRegistered", "CapabilityNotFound", "DriverError", "FirewallError", diff --git a/src/agent_kernel/errors.py b/src/agent_kernel/errors.py index 31bd70d..5cf6ab8 100644 --- a/src/agent_kernel/errors.py +++ b/src/agent_kernel/errors.py @@ -44,6 +44,10 @@ class FirewallError(AgentKernelError): # ── Registry / lookup errors ────────────────────────────────────────────────── +class CapabilityAlreadyRegistered(AgentKernelError): + """Raised when a capability with the same ID is already registered.""" + + class CapabilityNotFound(AgentKernelError): """Raised when a capability ID is not found in the registry.""" diff --git a/src/agent_kernel/registry.py 
b/src/agent_kernel/registry.py index 48515ff..8687ae8 100644 --- a/src/agent_kernel/registry.py +++ b/src/agent_kernel/registry.py @@ -4,7 +4,7 @@ import re -from .errors import CapabilityNotFound +from .errors import CapabilityAlreadyRegistered, CapabilityNotFound from .models import Capability, CapabilityRequest @@ -27,10 +27,10 @@ def register(self, capability: Capability) -> None: capability: The :class:`Capability` to register. Raises: - ValueError: If a capability with the same ID is already registered. + CapabilityAlreadyRegistered: If a capability with the same ID is already registered. """ if capability.capability_id in self._store: - raise ValueError( + raise CapabilityAlreadyRegistered( f"Capability '{capability.capability_id}' is already registered. " "Use a unique capability_id." ) diff --git a/tests/test_registry.py b/tests/test_registry.py index 72fd5a3..6e63597 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -4,7 +4,13 @@ import pytest -from agent_kernel import Capability, CapabilityNotFound, CapabilityRegistry, SafetyClass +from agent_kernel import ( + Capability, + CapabilityAlreadyRegistered, + CapabilityNotFound, + CapabilityRegistry, + SafetyClass, +) def _make_cap(cap_id: str, **kwargs: object) -> Capability: @@ -27,7 +33,7 @@ def test_register_and_get() -> None: def test_register_duplicate_raises() -> None: reg = CapabilityRegistry() reg.register(_make_cap("test.dup")) - with pytest.raises(ValueError, match="already registered"): + with pytest.raises(CapabilityAlreadyRegistered, match="already registered"): reg.register(_make_cap("test.dup")) From aa690f10022dcf03b8ab2627ee9f4f9bf48db37b Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:34:48 +0000 Subject: [PATCH 05/19] fix: wire default_timeout as fallback in HTTPDriver.execute() --- src/agent_kernel/drivers/http.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agent_kernel/drivers/http.py b/src/agent_kernel/drivers/http.py 
index d586c01..7ba8af9 100644 --- a/src/agent_kernel/drivers/http.py +++ b/src/agent_kernel/drivers/http.py @@ -19,7 +19,8 @@ class HTTPEndpoint: url: str method: str = "GET" headers: dict[str, str] = field(default_factory=dict) - timeout: float = 30.0 + timeout: float | None = None + """Per-endpoint timeout in seconds. Falls back to the driver's ``default_timeout``.""" class HTTPDriver: @@ -87,8 +88,12 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: else: json_body = {k: v for k, v in ctx.args.items() if k != "operation"} + effective_timeout = ( + endpoint.timeout if endpoint.timeout is not None else self._default_timeout + ) + try: - async with httpx.AsyncClient(headers=headers, timeout=endpoint.timeout) as client: + async with httpx.AsyncClient(headers=headers, timeout=effective_timeout) as client: if endpoint.method.upper() == "GET": response = await client.get(endpoint.url, params=params) elif endpoint.method.upper() == "POST": From 7fedd662b31376fc0a33cffa045d80ee47ef8ee5 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:38:48 +0000 Subject: [PATCH 06/19] fix: prevent IndexError on empty rows in HandleStore.expand() --- src/agent_kernel/handles.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/agent_kernel/handles.py b/src/agent_kernel/handles.py index ed06878..57ba367 100644 --- a/src/agent_kernel/handles.py +++ b/src/agent_kernel/handles.py @@ -154,15 +154,18 @@ def expand( for r in rows ] + if not rows: + table_preview: list[Any] = [] + elif isinstance(rows[0], dict): + table_preview = rows + else: + table_preview = [{"value": r} for r in rows] + return Frame( action_id=action_id, capability_id=handle.capability_id, response_mode=response_mode, - table_preview=rows - if isinstance(rows[0], dict) - else [{"value": r} for r in rows] - if rows - else [], + table_preview=table_preview, handle=handle, provenance=Provenance( capability_id=handle.capability_id, From 
32da6bce8d7dc80a628252f6ea782b68abd25513 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:43:22 +0000 Subject: [PATCH 07/19] fix: include full token and audit_id in CapabilityGrant --- src/agent_kernel/kernel.py | 3 ++- src/agent_kernel/models.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index f7c1a5d..f006af3 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -129,7 +129,8 @@ def grant_capability( request=request, principal=principal, decision=decision, - token_id=token.token_id, + token=token, + audit_id=audit_id, ) def get_token( diff --git a/src/agent_kernel/models.py b/src/agent_kernel/models.py index 60f9ba9..aa1d921 100644 --- a/src/agent_kernel/models.py +++ b/src/agent_kernel/models.py @@ -8,10 +8,13 @@ import datetime from dataclasses import dataclass, field -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal from .enums import SafetyClass, SensitivityTag +if TYPE_CHECKING: + from .tokens import CapabilityToken + ResponseMode = Literal["summary", "table", "handle_only", "raw"] @@ -116,8 +119,11 @@ class CapabilityGrant: decision: PolicyDecision """The policy decision that led to this grant.""" - token_id: str - """The token's unique identifier.""" + token: CapabilityToken + """The signed capability token issued for this grant.""" + + audit_id: str + """Unique audit identifier embedded in the token for traceability.""" # ── Routing ─────────────────────────────────────────────────────────────────── From b27b1a9e9c09812bad3e2aeca0958a9a4491a049 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:48:31 +0000 Subject: [PATCH 08/19] fix: validate max_rows constraint and raise PolicyDenied on invalid input --- src/agent_kernel/policy.py | 9 ++++++++- tests/test_policy.py | 21 +++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/agent_kernel/policy.py 
b/src/agent_kernel/policy.py index 7aecf41..1b64322 100644 --- a/src/agent_kernel/policy.py +++ b/src/agent_kernel/policy.py @@ -124,7 +124,14 @@ def evaluate( max_rows = _MAX_ROWS_SERVICE if "service" in roles else _MAX_ROWS_USER # Respect any tighter constraint from the request itself. if "max_rows" in constraints: - constraints["max_rows"] = min(int(constraints["max_rows"]), max_rows) + try: + requested = int(constraints["max_rows"]) + except (TypeError, ValueError) as exc: + raise PolicyDenied( + f"Invalid 'max_rows' constraint: {constraints['max_rows']!r} " + "is not a valid integer." + ) from exc + constraints["max_rows"] = min(max(requested, 0), max_rows) else: constraints["max_rows"] = max_rows diff --git a/tests/test_policy.py b/tests/test_policy.py index ba217cb..91c5e11 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -192,3 +192,24 @@ def test_max_rows_enforcement() -> None: _req("cap.r", max_rows=200), _cap("cap.r", SafetyClass.READ), p, justification="" ) assert dec.constraints["max_rows"] == 50 + + +def test_max_rows_invalid_raises_policy_denied() -> None: + """Non-numeric max_rows raises PolicyDenied, not bare ValueError.""" + p = Principal(principal_id="u1") + with pytest.raises(PolicyDenied, match="Invalid 'max_rows'"): + engine.evaluate( + _req("cap.r", max_rows="abc"), + _cap("cap.r", SafetyClass.READ), + p, + justification="", + ) + + +def test_max_rows_negative_clamped_to_zero() -> None: + """Negative max_rows is clamped to 0.""" + p = Principal(principal_id="u1") + dec = engine.evaluate( + _req("cap.r", max_rows=-10), _cap("cap.r", SafetyClass.READ), p, justification="" + ) + assert dec.constraints["max_rows"] == 0 From 5dc0b012a03c3bdc7e8e3e6821b73b6cc983880e Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 05:56:03 +0000 Subject: [PATCH 09/19] chore: remove unused _deep_copy_truncated and import copy --- src/agent_kernel/firewall/transform.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git 
a/src/agent_kernel/firewall/transform.py b/src/agent_kernel/firewall/transform.py index f570f63..c327cad 100644 --- a/src/agent_kernel/firewall/transform.py +++ b/src/agent_kernel/firewall/transform.py @@ -2,7 +2,6 @@ from __future__ import annotations -import copy import datetime import json from typing import Any @@ -188,17 +187,3 @@ def _cap_facts(facts: list[str], max_chars: int) -> list[str]: break result.append(fact) return result - - -def _deep_copy_truncated(data: Any, *, max_depth: int, depth: int = 0) -> Any: - """Deep-copy data, stopping recursion at *max_depth*.""" - if depth >= max_depth: - return repr(data)[:100] - if isinstance(data, dict): - return { - k: _deep_copy_truncated(v, max_depth=max_depth, depth=depth + 1) - for k, v in data.items() - } - if isinstance(data, list): - return [_deep_copy_truncated(v, max_depth=max_depth, depth=depth + 1) for v in data] - return copy.copy(data) if not isinstance(data, (int, float, str, bool, type(None))) else data From 8b7fa7994762c1100b826f8db3a44e9894c06105 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:01:14 +0000 Subject: [PATCH 10/19] fix: avoid json.loads on truncated JSON in raw mode transform --- src/agent_kernel/firewall/transform.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/agent_kernel/firewall/transform.py b/src/agent_kernel/firewall/transform.py index c327cad..43e7d8b 100644 --- a/src/agent_kernel/firewall/transform.py +++ b/src/agent_kernel/firewall/transform.py @@ -101,12 +101,17 @@ def transform( warnings.append("raw mode requires admin role; falling back to summary.") response_mode = "summary" else: - raw_str = _truncate_str(json.dumps(data, default=str), self._budgets.max_chars) + raw_size = len(json.dumps(data, default=str)) + if raw_size > self._budgets.max_chars: + warnings.append( + f"raw output ({raw_size} chars) exceeds budget " + f"({self._budgets.max_chars} chars); data returned untruncated." 
+ ) return Frame( action_id=action_id, capability_id=raw.capability_id, response_mode="raw", - raw_data=json.loads(raw_str) if raw_str else data, + raw_data=data, handle=handle, warnings=warnings, provenance=provenance, From 994260fda49bfae657e08a992d8a98fc94dde33f Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:02:48 +0000 Subject: [PATCH 11/19] fix: record effective response_mode in ActionTrace instead of requested --- src/agent_kernel/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index f006af3..813086b 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -275,7 +275,7 @@ async def invoke( token_id=token.token_id, invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), args=args, - response_mode=response_mode, + response_mode=frame.response_mode, driver_id=used_driver_id, handle_id=handle.handle_id if handle else None, ) From 1cc59299cbbdbc604598dfe8b32771edc1c08f2b Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:12:13 +0000 Subject: [PATCH 12/19] docs: fix Kernel docstring example to include principal arg --- src/agent_kernel/kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index 813086b..0e1b8bc 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -42,7 +42,7 @@ class Kernel: requests = kernel.request_capabilities("list invoices") grant = kernel.grant_capability(requests[0], principal, justification="...") - frame = await kernel.invoke(grant.token, args={"operation": "list_invoices"}) + frame = await kernel.invoke(grant.token, principal=principal, args={"operation": "list_invoices"}) """ def __init__( From 610a18f01ebd475afc504d06f468d3261e19b359 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:33:02 +0000 Subject: [PATCH 13/19] fix: require justification for DESTRUCTIVE operations --- 
src/agent_kernel/policy.py | 6 ++++++ tests/test_policy.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/agent_kernel/policy.py b/src/agent_kernel/policy.py index 1b64322..14268a2 100644 --- a/src/agent_kernel/policy.py +++ b/src/agent_kernel/policy.py @@ -100,6 +100,12 @@ def evaluate( ) elif capability.safety_class == SafetyClass.DESTRUCTIVE: + if len(justification) < _MIN_JUSTIFICATION: + raise PolicyDenied( + f"DESTRUCTIVE capabilities require a justification of at least " + f"{_MIN_JUSTIFICATION} characters. " + f"Got {len(justification)} characters." + ) if "admin" not in roles: raise PolicyDenied( f"DESTRUCTIVE capabilities require the 'admin' role. " diff --git a/tests/test_policy.py b/tests/test_policy.py index 91c5e11..e00127d 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -122,6 +122,17 @@ def test_write_allowed_admin_role() -> None: # ── DESTRUCTIVE ──────────────────────────────────────────────────────────────── +def test_destructive_denied_short_justification() -> None: + p = Principal(principal_id="u1", roles=["admin"]) + with pytest.raises(PolicyDenied, match="DESTRUCTIVE capabilities require a justification"): + engine.evaluate( + _req("cap.d"), + _cap("cap.d", SafetyClass.DESTRUCTIVE), + p, + justification="short", + ) + + def test_destructive_denied_no_admin() -> None: p = Principal(principal_id="u1", roles=["writer"]) with pytest.raises(PolicyDenied, match="admin"): From bc75e42eb04db7094917715af9cfec51288951ed Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:40:57 +0000 Subject: [PATCH 14/19] fix: remove duplicate Budgets class, consolidate to firewall.budgets --- src/agent_kernel/firewall/budgets.py | 4 ++-- src/agent_kernel/firewall/transform.py | 11 +++-------- src/agent_kernel/models.py | 10 ---------- tests/test_models.py | 2 +- 4 files changed, 6 insertions(+), 21 deletions(-) diff --git a/src/agent_kernel/firewall/budgets.py b/src/agent_kernel/firewall/budgets.py index 1bcb343..9ad367f 
100644 --- a/src/agent_kernel/firewall/budgets.py +++ b/src/agent_kernel/firewall/budgets.py @@ -1,7 +1,7 @@ """Budgets dataclass for the context firewall. -Re-exported from :mod:`agent_kernel.models` for convenience; also available -directly as ``agent_kernel.firewall.Budgets``. +Canonical definition of :class:`Budgets`. Re-exported via +``agent_kernel.firewall`` and the top-level ``agent_kernel`` package. """ from __future__ import annotations diff --git a/src/agent_kernel/firewall/transform.py b/src/agent_kernel/firewall/transform.py index 43e7d8b..48507df 100644 --- a/src/agent_kernel/firewall/transform.py +++ b/src/agent_kernel/firewall/transform.py @@ -7,21 +7,16 @@ from typing import Any from ..models import ( - Budgets, Frame, Handle, Provenance, RawResult, ResponseMode, ) -from .budgets import Budgets as FirewallBudgets +from .budgets import Budgets from .redaction import redact from .summarize import summarize -# Use the models.Budgets for the Frame; the firewall.Budgets is re-exported for -# back-compat but they are structurally identical. -_Budgets = Budgets - class Firewall: """Transforms :class:`RawResult` objects into LLM-safe :class:`Frame` objects. @@ -32,9 +27,9 @@ class Firewall: - Four response modes: ``summary``, ``table``, ``handle_only``, ``raw``. 
""" - def __init__(self, budgets: _Budgets | FirewallBudgets | None = None) -> None: + def __init__(self, budgets: Budgets | None = None) -> None: if budgets is None: - self._budgets: _Budgets | FirewallBudgets = _Budgets() + self._budgets = Budgets() else: self._budgets = budgets diff --git a/src/agent_kernel/models.py b/src/agent_kernel/models.py index aa1d921..d2f0878 100644 --- a/src/agent_kernel/models.py +++ b/src/agent_kernel/models.py @@ -173,16 +173,6 @@ class Provenance: action_id: str -@dataclass(slots=True) -class Budgets: - """Budget constraints for the context firewall.""" - - max_rows: int = 50 - max_fields: int = 20 - max_chars: int = 4000 - max_depth: int = 3 - - @dataclass(slots=True) class FieldSpec: """Describes a single field in a structured result.""" diff --git a/tests/test_models.py b/tests/test_models.py index cf5b024..eaee735 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -5,9 +5,9 @@ import datetime from agent_kernel.enums import SafetyClass, SensitivityTag +from agent_kernel.firewall.budgets import Budgets from agent_kernel.models import ( ActionTrace, - Budgets, Capability, CapabilityRequest, Frame, From ccb5dd656f49cdbcdc53eafae13878323e5f74a0 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:48:41 +0000 Subject: [PATCH 15/19] fix: tighten _PHONE_RE to require phone-like structure, add regex tests --- src/agent_kernel/firewall/redaction.py | 16 ++- tests/test_redaction.py | 188 +++++++++++++++++++++++++ 2 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 tests/test_redaction.py diff --git a/src/agent_kernel/firewall/redaction.py b/src/agent_kernel/firewall/redaction.py index dfc0591..999e240 100644 --- a/src/agent_kernel/firewall/redaction.py +++ b/src/agent_kernel/firewall/redaction.py @@ -22,7 +22,21 @@ ) _EMAIL_RE = re.compile(r"[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+") -_PHONE_RE = re.compile(r"\+?[\d\s\-().]{7,}") +_PHONE_RE = re.compile( + r""" + (? 
None: + assert _EMAIL_RE.search(text), f"Expected match for: {text}" + + @pytest.mark.parametrize( + "text", + [ + "plaintext", + "user@", + "@domain.com", + "user@@domain.com", + ], + ) + def test_rejects_non_emails(self, text: str) -> None: + assert not _EMAIL_RE.search(text), f"Unexpected match for: {text}" + + +# ── _PHONE_RE ────────────────────────────────────────────────────────────────── + + +class TestPhoneRegex: + """True-positive and true-negative tests for _PHONE_RE.""" + + @pytest.mark.parametrize( + "text", + [ + "+1-555-123-4567", + "(555) 123-4567", + "555-123-4567", + "555.123.4567", + "+44 20 7946 0958", + "(020) 7946-0958", + "123 456 7890", + "+1 800 555 0199", + ], + ) + def test_matches_phone_numbers(self, text: str) -> None: + assert _PHONE_RE.search(text), f"Expected match for: {text}" + + @pytest.mark.parametrize( + "text", + [ + "2026-03-04", + "100.00 - 200.00", + "v1.2.3.456", + "192.168.1.100", + "order-12345", + "1234567", + "(100)", + "3.14159", + "2026/03/04", + "ID: 9876543", + ], + ) + def test_rejects_non_phones(self, text: str) -> None: + assert not _PHONE_RE.search(text), f"Unexpected match for: {text}" + + +# ── _CARD_RE ────────────────────────────────────────────────────────────────── + + +class TestCardRegex: + """True-positive and true-negative tests for _CARD_RE.""" + + @pytest.mark.parametrize( + "text", + [ + "4111111111111111", + "4111 1111 1111 1111", + "4111-1111-1111-1111", + "5500000000000004", + ], + ) + def test_matches_card_numbers(self, text: str) -> None: + assert _CARD_RE.search(text), f"Expected match for: {text}" + + @pytest.mark.parametrize( + "text", + [ + "12345", + "abcdefghijklmnop", + "123-45-6789", + ], + ) + def test_rejects_non_cards(self, text: str) -> None: + assert not _CARD_RE.search(text), f"Unexpected match for: {text}" + + +# ── _SSN_RE ─────────────────────────────────────────────────────────────────── + + +class TestSSNRegex: + """True-positive and true-negative tests for _SSN_RE.""" + + 
@pytest.mark.parametrize( + "text", + [ + "123-45-6789", + "123 45 6789", + ], + ) + def test_matches_ssn(self, text: str) -> None: + assert _SSN_RE.search(text), f"Expected match for: {text}" + + @pytest.mark.parametrize( + "text", + [ + "123456789", + "12-345-6789", + "1234-56-789", + "abc-de-fghi", + ], + ) + def test_rejects_non_ssn(self, text: str) -> None: + assert not _SSN_RE.search(text), f"Unexpected match for: {text}" + + +# ── redact() integration ────────────────────────────────────────────────────── + + +class TestRedactFunction: + """Integration tests for the redact() function with pattern redaction.""" + + def test_phone_in_string_redacted(self) -> None: + data = "Call me at (555) 123-4567 please" + result, warnings = redact(data) + assert "(555) 123-4567" not in result + assert "[REDACTED]" in result + assert len(warnings) == 1 + + def test_date_not_redacted(self) -> None: + data = "Date: 2026-03-04" + result, warnings = redact(data) + assert result == data + assert not warnings + + def test_price_range_not_redacted(self) -> None: + data = "Price: 100.00 - 200.00" + result, warnings = redact(data) + assert result == data + assert not warnings + + def test_ip_address_not_redacted(self) -> None: + data = "Server: 192.168.1.100" + result, warnings = redact(data) + assert result == data + assert not warnings + + def test_email_in_dict_field_redacted(self) -> None: + data = {"email": "user@example.com", "name": "Alice"} + result, warnings = redact(data) + assert result["email"] == "[REDACTED]" + assert result["name"] == "Alice" + + def test_ssn_in_string_redacted(self) -> None: + data = "SSN: 123-45-6789" + result, warnings = redact(data) + assert "123-45-6789" not in result + assert "[REDACTED]" in result From c32598563581fe3662cadd2743989b8286d7bd2f Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:52:05 +0000 Subject: [PATCH 16/19] fix: bound HandleStore with max_entries cap and periodic auto-eviction --- src/agent_kernel/handles.py | 32 
++++++++++++++++++++++++++++-- tests/test_handles.py | 39 +++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/agent_kernel/handles.py b/src/agent_kernel/handles.py index 57ba367..2f4b8e9 100644 --- a/src/agent_kernel/handles.py +++ b/src/agent_kernel/handles.py @@ -13,11 +13,23 @@ class HandleStore: """Stores full capability results by handle ID with TTL-based expiry. - Entries are evicted lazily (on access) or explicitly via :meth:`evict_expired`. + Entries are evicted lazily (on access), periodically during :meth:`store`, + or explicitly via :meth:`evict_expired`. A *max_entries* cap prevents + unbounded memory growth in long-lived processes — when the cap is exceeded + the oldest entries are dropped after expired ones are cleared. """ - def __init__(self, default_ttl_seconds: int = 3600) -> None: + _EVICT_INTERVAL: int = 128 # run evict_expired() every N store() calls + + def __init__( + self, + default_ttl_seconds: int = 3600, + *, + max_entries: int = 10_000, + ) -> None: self._default_ttl = default_ttl_seconds + self._max_entries = max_entries + self._store_count = 0 self._data: dict[str, Any] = {} self._meta: dict[str, Handle] = {} @@ -51,6 +63,22 @@ def store( ) self._data[handle.handle_id] = data self._meta[handle.handle_id] = handle + + # Periodic eviction of expired entries + self._store_count += 1 + if self._store_count % self._EVICT_INTERVAL == 0: + self.evict_expired() + + # Cap enforcement: evict oldest entries when over the limit + if len(self._meta) > self._max_entries: + self.evict_expired() # clear expired first + overflow = len(self._meta) - self._max_entries + if overflow > 0: + oldest = sorted(self._meta, key=lambda hid: self._meta[hid].created_at) + for hid in oldest[:overflow]: + self._data.pop(hid, None) + self._meta.pop(hid, None) + return handle # ── Retrieval ───────────────────────────────────────────────────────────── diff --git a/tests/test_handles.py b/tests/test_handles.py index 
9e44348..12689ae 100644 --- a/tests/test_handles.py +++ b/tests/test_handles.py @@ -106,3 +106,42 @@ def test_expand_handle_not_found(store: HandleStore) -> None: ) with pytest.raises(HandleNotFound): store.expand(fake_handle, query={}) + + +# ── Bounded store ────────────────────────────────────────────────────────────── + + +def test_max_entries_evicts_oldest() -> None: + s = HandleStore(default_ttl_seconds=3600, max_entries=5) + handles = [s.store("cap.x", [i]) for i in range(7)] + # Only 5 should remain; the 2 oldest were evicted + assert len(s._meta) == 5 + # Oldest handles should be gone + with pytest.raises(HandleNotFound): + s.get(handles[0].handle_id) + with pytest.raises(HandleNotFound): + s.get(handles[1].handle_id) + # Newest should still be accessible + assert s.get(handles[6].handle_id) == [6] + + +def test_max_entries_prefers_expired_over_live() -> None: + s = HandleStore(default_ttl_seconds=3600, max_entries=3) + # Store 2 already-expired + 1 live + s.store("cap.x", ["expired1"], ttl_seconds=-1) + s.store("cap.x", ["expired2"], ttl_seconds=-1) + live = s.store("cap.x", ["live"], ttl_seconds=3600) + # Now add a 4th — should evict the 2 expired first, then no overflow + new = s.store("cap.x", ["new"], ttl_seconds=3600) + assert len(s._meta) == 2 + assert s.get(live.handle_id) == ["live"] + assert s.get(new.handle_id) == ["new"] + + +def test_periodic_eviction_on_store() -> None: + s = HandleStore(default_ttl_seconds=3600, max_entries=10_000) + # Fill with expired entries below the cap + for i in range(HandleStore._EVICT_INTERVAL): + s.store("cap.x", [i], ttl_seconds=-1) + # All expired entries should have been evicted at the interval boundary + assert len(s._meta) == 0 From e6f8c380affdce4af7e0fd43640550d4ce859b99 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 06:58:07 +0000 Subject: [PATCH 17/19] refactor: deduplicate get_token via grant_capability, bring kernel.py under 300 lines --- src/agent_kernel/kernel.py | 31 
+++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index 0e1b8bc..7555b2b 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -142,31 +142,14 @@ def get_token( ) -> CapabilityToken: """Like :meth:`grant_capability` but returns the token directly. - This is a convenience method for use in :meth:`invoke`. - - Args: - request: The capability request. - principal: The requesting principal. - justification: Free-text justification. - - Returns: - A signed :class:`CapabilityToken`. - - Raises: - PolicyDenied: If the policy engine rejects the request. - CapabilityNotFound: If the capability is not registered. + Convenience wrapper for callers that don't need the full + :class:`CapabilityGrant`. Delegates entirely to + :meth:`grant_capability`; see its docstring for parameter and + exception details. """ - capability = self._registry.get(request.capability_id) - decision = self._policy.evaluate( - request, capability, principal, justification=justification - ) - audit_id = str(uuid.uuid4()) - return self._token_provider.issue( - capability.capability_id, - principal.principal_id, - constraints=decision.constraints, - audit_id=audit_id, - ) + return self.grant_capability( + request, principal, justification=justification + ).token async def invoke( self, From 22f61a8b0a7788c819d36ffa10033f44d2b47a5c Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 07:01:43 +0000 Subject: [PATCH 18/19] fix: add threading.Lock to _get_secret, fix test_dev_secret state leakage --- src/agent_kernel/tokens.py | 23 +++++++++++++++-------- tests/test_tokens.py | 18 +++++++++++------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/agent_kernel/tokens.py b/src/agent_kernel/tokens.py index 5520ebb..babbace 100644 --- a/src/agent_kernel/tokens.py +++ b/src/agent_kernel/tokens.py @@ -9,6 +9,7 @@ import logging import os import secrets +import 
threading import uuid from dataclasses import dataclass, field from typing import Any, Protocol @@ -18,21 +19,27 @@ logger = logging.getLogger(__name__) _DEV_SECRET: str | None = None +_DEV_SECRET_LOCK = threading.Lock() def _get_secret() -> str: - """Return the HMAC secret from the environment or generate a dev fallback.""" + """Return the HMAC secret from the environment or generate a dev fallback. + + Thread-safe: a :data:`threading.Lock` ensures only one thread generates + the fallback secret. + """ global _DEV_SECRET secret = os.environ.get("AGENT_KERNEL_SECRET") if secret: return secret - if _DEV_SECRET is None: - _DEV_SECRET = secrets.token_hex(32) - logger.warning( - "AGENT_KERNEL_SECRET is not set. " - "Using a random development secret — tokens will not survive restarts. " - "Set AGENT_KERNEL_SECRET in production." - ) + with _DEV_SECRET_LOCK: + if _DEV_SECRET is None: + _DEV_SECRET = secrets.token_hex(32) + logger.warning( + "AGENT_KERNEL_SECRET is not set. " + "Using a random development secret — tokens will not survive restarts. " + "Set AGENT_KERNEL_SECRET in production." 
+ ) return _DEV_SECRET diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 3b0fd97..5c786a0 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -94,12 +94,16 @@ def test_dev_secret_warning(caplog: pytest.LogCaptureFixture) -> None: """A provider with no secret should generate a warning.""" import logging - # Reset dev secret to force warning import agent_kernel.tokens as tok_mod - tok_mod._DEV_SECRET = None - provider_no_secret = HMACTokenProvider(secret=None) - with caplog.at_level(logging.WARNING, logger="agent_kernel.tokens"): - token = provider_no_secret.issue("cap.x", "user-1") - assert "AGENT_KERNEL_SECRET" in caplog.text - assert token.signature != "" + # Save and restore _DEV_SECRET to avoid leaking state to other tests + original = tok_mod._DEV_SECRET + try: + tok_mod._DEV_SECRET = None + provider_no_secret = HMACTokenProvider(secret=None) + with caplog.at_level(logging.WARNING, logger="agent_kernel.tokens"): + token = provider_no_secret.issue("cap.x", "user-1") + assert "AGENT_KERNEL_SECRET" in caplog.text + assert token.signature != "" + finally: + tok_mod._DEV_SECRET = original From 692e06f6abf0742f79dd228533d3f40555bf4f22 Mon Sep 17 00:00:00 2001 From: dgenio Date: Wed, 4 Mar 2026 07:06:20 +0000 Subject: [PATCH 19/19] style: fix ruff format for get_token return statement --- src/agent_kernel/kernel.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/agent_kernel/kernel.py b/src/agent_kernel/kernel.py index 7555b2b..43334d7 100644 --- a/src/agent_kernel/kernel.py +++ b/src/agent_kernel/kernel.py @@ -147,9 +147,7 @@ def get_token( :meth:`grant_capability`; see its docstring for parameter and exception details. """ - return self.grant_capability( - request, principal, justification=justification - ).token + return self.grant_capability(request, principal, justification=justification).token async def invoke( self,