diff --git a/.claude/settings.local.json b/.claude/settings.local.json index d646372..652dabc 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -1,117 +1,12 @@ { "permissions": { "allow": [ - "Bash(wc:*)", - "Bash(git:*)", - "Bash(black:*)", - "Bash(ruff check:*)", - "Bash(kill %1:*)", - "Bash(gh run:*)", - "WebFetch(domain:pypi.org)", - "Bash(brew install:*)", - "Bash(gh:*)", - "Bash(mypy:*)", - "WebFetch(domain:github.com)", - "WebFetch(domain:raw.githubusercontent.com)", - "Bash(pip install:*)", - "Bash(black .:*)", - "Bash(grep -r \"from langfuse\" /mnt/c/Users/elish/code/shekel/shekel --include=\"*.py\" 2>/dev/null | head -20)", - "Bash(python3 -c \"\nimport tomllib\nwith open\\('pyproject.toml', 'rb'\\) as f:\n data = tomllib.load\\(f\\)\n \nprint\\('all extra:', data['project']['optional-dependencies']['all']\\)\nprint\\('all-models extra:', data['project']['optional-dependencies']['all-models']\\)\nprint\\('langfuse extra:', data['project']['optional-dependencies']['langfuse']\\)\n\")", - "Bash(python -m pytest tests/integrations/ -v 2>&1 | head -150)", - "Bash(python3 -m pytest tests/integrations/ -v --tb=short 2>&1 | head -200)", - "Bash(python3 -m py_compile shekel/integrations/langfuse.py shekel/integrations/base.py shekel/integrations/registry.py shekel/integrations/async_queue.py 2>&1)", - "Bash(grep -r \"import langfuse\\\\|from langfuse\" /mnt/c/Users/elish/code/shekel/tests --include=\"*.py\" 2>/dev/null | head -5)", - "Bash(grep \"\\\\.utilization\" /mnt/c/Users/elish/code/shekel/examples/langfuse/*.py)", - "Bash(git diff:*)", - "Bash(grep -A 1 \"@property\" /mnt/c/Users/elish/code/shekel/shekel/_budget.py | grep \"def \" | sed 's/.*def //' | sed 's/\\(.*$//')", - "Bash(python -m pytest tests/test_nested_auto_capping.py tests/integrations/test_core_integration.py tests/integrations/test_adapter_pattern.py tests/test_decorator.py -v 2>&1 | tail -40)", - "Bash(python3 -m pytest tests/test_nested_auto_capping.py 
tests/integrations/test_core_integration.py tests/integrations/test_adapter_pattern.py tests/test_decorator.py -v 2>&1 | tail -50)", - "Bash(uv run:*)", - "Bash(find . -name \"pytest\" -o -name \"python3*\" 2>/dev/null | head -5 && cat pyproject.toml | grep -A3 \"\\\\[tool.pytest\")", - "Bash(pip3 show:*)", - "Bash(pip3 install:*)", - "Bash(/usr/bin/python3 -m pip install --user black 2>&1 | tail -3)", - "Bash(/usr/bin/python3 -c \"import ensurepip; ensurepip.bootstrap\\(\\)\" 2>&1 && /usr/bin/python3 -m pip install --user black 2>&1 | tail -3)", - "Bash(apt-get install:*)", - "Bash(sudo apt-get:*)", - "Read(//mnt/c/Users/elish/AppData/Local/Programs/**)", - "Read(//mnt/c/Users/elish/AppData/Roaming/**)", - "Read(//mnt/c/Users/elish/**)", - "Bash(export PATH=\"$HOME/.pyenv/bin:$HOME/.pyenv/shims:$PATH\" && black --version)", - "Bash(export PATH=\"$HOME/.pyenv/bin:$HOME/.pyenv/shims:$PATH\" && pyenv shell 3.11.15 && pip install --user black && black --version)", - "Bash(~/.pyenv/versions/3.11.15/bin/pip install:*)", - "Bash(PYBIN=~/.pyenv/versions/3.11.15/bin && echo \"=== black ===\" && $PYBIN/black --check . 2>&1 | tail -3 && echo \"=== isort ===\" && $PYBIN/isort --check-only . 2>&1 | tail -5 && echo \"=== ruff ===\" && $PYBIN/ruff check . 2>&1 | tail -5 && echo \"=== mypy ===\" && $PYBIN/mypy shekel/ 2>&1 | tail -5)", - "Bash(BIN=~/.pyenv/versions/3.11.15/bin && echo \"=== black ===\" && /home/elish/.local/bin/black --check . 2>&1 | tail -3)", - "Bash(BIN=~/.pyenv/versions/3.11.15/bin && echo \"=== isort ===\" && $BIN/isort --check-only . 2>&1 | tail -5 && echo \"=== ruff ===\" && $BIN/ruff check . 2>&1 | tail -5 && echo \"=== mypy ===\" && $BIN/mypy shekel/ 2>&1 | tail -5)", - "Bash(~/.pyenv/versions/3.11.15/bin/isort --check-only .)", - "Read(//mnt/c/Users/elish/code/shekel/**)", - "Bash(~/.pyenv/versions/3.11.15/bin/ruff check:*)", - "Bash(~/.pyenv/versions/3.11.15/bin/mypy shekel/)", - "Bash(~/.pyenv/versions/3.11.15/bin/isort --check-only . 
2>&1; echo \"EXIT: $?\")", - "Bash(~/.pyenv/versions/3.11.15/bin/isort examples/langfuse/complete_demo.py examples/langfuse/quickstart.py tests/test_nested_auto_capping.py tests/integrations/test_adapter_pattern.py tests/integrations/test_async_queue.py && echo \"done\")", - "Bash(/home/elish/.local/bin/black --line-length 100 examples/langfuse/complete_demo.py examples/langfuse/quickstart.py tests/test_nested_auto_capping.py tests/integrations/test_adapter_pattern.py tests/integrations/test_async_queue.py && echo \"done\")", - "Bash(/home/elish/.local/bin/black --line-length 100 examples/langfuse/complete_demo.py tests/integrations/test_core_integration.py tests/test_nested_auto_capping.py 2>&1)", - "Bash(/home/elish/.local/bin/black --check .)", - "Bash(/home/elish/.local/bin/black --check . 2>&1; echo \"BLACK EXIT: $?\")", - "Bash(/home/elish/.local/bin/black --line-length 100 tests/integrations/test_adapter_pattern.py tests/integrations/test_langfuse_circuit_break.py tests/integrations/test_langfuse_cost_streaming.py tests/integrations/test_langfuse_fallback.py tests/integrations/test_langfuse_nested_mapping.py 2>&1)", - "Bash(~/.pyenv/versions/3.11.15/bin/pytest tests/ -v --cov=shekel --cov-report=term-missing --cov-fail-under=90 2>&1)", - "Bash(npm list:*)", - "WebSearch", - "Bash(which npx:*)", - "Read(//usr/bin/**)", - "Bash(npm --version)", - "Bash(npm --version && npx --version)", - "Bash(npx bmad-method:*)", - "Bash(curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash 2>&1 | tail -5)", - "Bash(export NVM_DIR=\"$HOME/.nvm\" && . 
\"$NVM_DIR/nvm.sh\" && nvm install 22 2>&1 | tail -5 && node --version)", - "Read(//home/elish/.nvm/**)", - "Bash(source ~/.nvm/nvm.sh)", - "Bash(nvm install:*)", - "Bash(grep -r \"v0\\\\\\\\.2\\\\\\\\.[0-9]\" /mnt/c/Users/elish/code/shekel/docs/*.md /mnt/c/Users/elish/code/shekel/docs/**/*.md | grep -v \"CHANGELOG\\\\|changelog\" | head -20)", - "Bash(cat /home/elish/.claude/projects/-mnt-c-Users-elish-code-shekel/511ceb7f-1b10-477c-a628-4f32a473aafd/tool-results/toolu_012eRcqWe5kFZnaqkfoTQNcU.txt | jq -r '.[0].text' > /mnt/c/Users/elish/code/shekel/IMPLEMENTATION_PLAN.md)", - "Bash(pyenv local:*)", - "Bash(python -m pytest tests/providers/test_provider_test_base.py -v 2>&1)", - "Bash(python -m pytest tests/ --ignore=tests/providers -x -q 2>&1)", - "Bash(python -m pytest tests/providers/test_registry.py -v 2>&1 | head -40)", - "Bash(python -m pytest tests/providers/ -v 2>&1 | tail -60)", - "Bash(python -m pytest --tb=short -q 2>&1 | tail -20)", - "Bash(python -m pytest --tb=short -q 2>&1 | tail -30)", - "Bash(python -m pytest --tb=short -q 2>&1 | tail -5)", - "Bash(isort --check-only shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1)", - "Bash(isort shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1 && ruff check shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1)", - "Bash(mkdocs build:*)", - "Bash(black shekel/providers/ && python -m pytest tests/providers/ -q 2>&1 | tail -5)", - "Bash(echo \"=== BLACK ===\" && black --check shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1 | tail -3 && echo \"=== ISORT ===\" && isort --check-only shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1 | tail -3 && echo \"=== RUFF ===\" && ruff check shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1 && echo \"=== MYPY ===\" && mypy shekel/ 2>&1 | tail -3)", - "Bash(isort shekel/providers/ tests/providers/ 
examples/cohere_adapter_template.py && black shekel/providers/ tests/providers/ examples/cohere_adapter_template.py && ruff check shekel/providers/ tests/providers/ examples/cohere_adapter_template.py 2>&1)", - "Bash(python -m pytest tests/test_call_limits.py -xvs 2>&1 | head -200)", - "Bash(python -m pytest tests/test_fallback.py -xvs 2>&1 | head -100)", - "Bash(python -m pytest tests/test_fallback.py -xvs 2>&1 | head -200)", - "Bash(python -m pytest tests/test_fallback.py -xvs 2>&1 | tail -50)", - "Bash(python -m pytest tests/test_decorator.py -xvs 2>&1 | tail -50)", - "Bash(python -m pytest tests/test_summary.py -xvs 2>&1 | tail -50)", - "Bash(python -m pytest tests/test_session_budget.py -xvs 2>&1 | tail -50)", - "Bash(python -m pytest tests/integrations/test_ollama_integration.py -xvs 2>&1 | tail -100)", - "Bash(python -m pytest tests/test_fallback.py tests/test_decorator.py tests/test_summary.py tests/test_session_budget.py tests/integrations/test_ollama_integration.py -v 2>&1 | tail -100)", - "Bash(python -m pytest tests/test_fallback.py tests/test_decorator.py tests/test_summary.py tests/test_session_budget.py tests/integrations/test_ollama_integration.py --tb=short 2>&1 | grep -E \"\\(PASSED|FAILED|ERROR|passed|failed\\)\" | tail -5)", - "Bash(python -m pytest tests/ -q --tb=no 2>&1 | tail -20)", - "Bash(python -m pytest tests/ -q --tb=line 2>&1 | tail -30)", - "Bash(python -m pytest tests/ -q --tb=no 2>&1 | tail -30)", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_model_rewritten_in_kwargs -xvs 2>&1 | tail -50)", - "Bash(python -m pytest tests/test_fallback.py -xvs 2>&1 | head -150)", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_model_rewritten_in_kwargs -xvs 2>&1 | grep -A 20 \"AssertionError\")", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_model_rewritten_in_kwargs -xvs 2>&1 | tail -30)", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_model_rewritten_in_kwargs 
tests/test_fallback.py::test_fallback_spent_tracks_separately -xvs 2>&1 | tail -20)", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_model_rewritten_in_kwargs -xvs 2>&1 | tail -20)", - "Bash(python -m pytest tests/ -q --tb=no 2>&1 | tail -10)", - "Bash(python -m pytest tests/test_fallback.py::test_fallback_small_call_within_budget -xvs 2>&1 | grep -A 5 \"assert\")", - "Bash(GIT_EDITOR=true git merge --continue)", - "Bash(python -m ruff check . 2>&1)", - "Bash(python -m mypy shekel/ 2>&1)", - "Bash(python -m black --check . 2>&1)", - "Bash(python -m isort --check-only . 2>&1)", - "Bash(python -m pytest --cov=shekel --cov-report=term-missing 2>&1)", - "Bash(python -c \"\nimport yaml, os\nwith open\\('/mnt/c/Users/elish/code/shekel/mkdocs.yml'\\) as f:\n cfg = yaml.safe_load\\(f\\)\n\ndef extract_paths\\(nav\\):\n for item in nav:\n if isinstance\\(item, dict\\):\n for k, v in item.items\\(\\):\n if isinstance\\(v, str\\):\n yield v\n elif isinstance\\(v, list\\):\n yield from extract_paths\\(v\\)\n\nmissing = []\nfor path in extract_paths\\(cfg['nav']\\):\n full = f'/mnt/c/Users/elish/code/shekel/docs/{path}'\n if not os.path.exists\\(full\\):\n missing.append\\(path\\)\n\nif missing:\n print\\('MISSING:', missing\\)\nelse:\n print\\('All nav files exist.'\\)\n\")", - "Bash(python -m mypy shekel/ --verbose 2>&1 | grep \"_pytest\" | head -10)" + "Bash(*)", + "Read(*)", + "Edit(*)", + "Write(*)", + "WebFetch(*)", + "WebSearch(*)" ] } } diff --git a/.python-version b/.python-version deleted file mode 100644 index 28d9a01..0000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.12.13 diff --git a/CHANGELOG.md b/CHANGELOG.md index c488ee8..ed9b87e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,75 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.1] - 2026-03-18 + +### Added + +- **`ShekelRuntime`** (`shekel/_runtime.py`) — framework detection and adapter wiring scaffold; called 
automatically at `budget.__enter__()` / `__exit__()` (and async variants) + - `ShekelRuntime.register(AdapterClass)` — class-level registry for framework adapters; adapters are probed once at budget open and released at budget close + - `probe()` — activates all registered adapters; silently skips adapters that raise `ImportError` (framework not installed) + - `release()` — deactivates adapters on budget exit; suppresses cleanup exceptions to avoid masking original errors + +- **`ComponentBudget`** (`shekel/_budget.py`) — lightweight dataclass for per-component cap tracking (`name`, `max_usd`, `_spent`) + +- **`Budget.node(name, max_usd)`** — register an explicit USD cap for a LangGraph node; returns `self` for chaining + +- **`Budget.agent(name, max_usd)`** — register an explicit USD cap for a named agent (CrewAI / OpenClaw); returns `self` for chaining + +- **`Budget.task(name, max_usd)`** — register an explicit USD cap for a named task (CrewAI); returns `self` for chaining + +- **`Budget.chain(name, max_usd)`** — register an explicit USD cap for a named LangChain chain; returns `self` for chaining; enforced by `LangChainRunnerAdapter` + +- **6 new exception classes** (`shekel/exceptions.py`); the five `*BudgetExceededError` subclasses inherit from `BudgetExceededError`: + - `NodeBudgetExceededError(node_name, spent, limit)` — raised when a LangGraph node exceeds its cap + - `AgentBudgetExceededError(agent_name, spent, limit)` — raised when an agent exceeds its cap + - `TaskBudgetExceededError(task_name, spent, limit)` — raised when a task exceeds its cap + - `SessionBudgetExceededError(agent_name, spent, limit, window=None)` — raised when a rolling-window agent session exceeds its budget + - `ChainBudgetExceededError(chain_name, spent, limit)` — raised when a LangChain chain exceeds its cap + - `BudgetConfigMismatchError` — raised by `RedisBackend` when a budget name is reused with different limits/windows + +- **`budget.tree()` enhancement** — renders registered node/agent/task/chain component 
budgets below the children block; shows `[node]`, `[agent]`, `[task]`, `[chain]` labels with spent / limit / percentage + +- **`LangGraphAdapter`** (`shekel/providers/langgraph.py`) — transparent node-level circuit breaking for LangGraph; zero user code changes required + - Patches `StateGraph.add_node()` at `budget.__enter__()` so every node — sync and async — gets a pre-execution budget gate + - Pre-execution gate: raises `NodeBudgetExceededError` before the node body runs if the explicit node cap or parent budget is exhausted + - Post-execution attribution: spend delta credited to `ComponentBudget._spent` so `budget.tree()` shows per-node costs + - Reference-counted patch: nested budgets don't double-patch; restored when the last budget context closes + - Automatically skipped (silent `ImportError`) when `langgraph` is not installed + +- **`LangChainRunnerAdapter`** (`shekel/providers/langchain.py`) — transparent chain-level circuit breaking for LangChain; zero user code changes required + - Patches `Runnable._call_with_config`, `_acall_with_config`, and `RunnableSequence.invoke`/`ainvoke` + - Pre-execution gate: raises `ChainBudgetExceededError` before the chain body runs if the explicit chain cap or parent budget is exhausted + - Reference-counted patch: same nesting semantics as `LangGraphAdapter` + - Automatically skipped when `langchain_core` is not installed + +- **Multi-cap temporal budget spec** — `"$5/hr + 100 calls/hr"` string DSL for simultaneous USD + call-count caps with independent rolling windows + - `_parse_cap_spec()` — parses compound spec strings into a list of `(counter, limit, window_s)` triples + - `TemporalBudget` supports `usd`, `llm_calls`, `tool_calls`, and `tokens` counters simultaneously + - All-or-nothing atomicity: if any counter would exceed its limit, no counters are incremented + +- **`RedisBackend`** (`shekel/backends/redis.py`) — synchronous Redis-backed rolling-window budget backend for distributed enforcement + - Atomic 
all-or-nothing Lua script (single round-trip per call) + - Lazy connection with connection pool reuse + - Circuit breaker: stops calling Redis after N consecutive errors (configurable threshold + cooldown) + - Fail-closed (default) or fail-open (`on_unavailable="open"`) on backend unavailability + - `BudgetConfigMismatchError` when a budget name is reused with different limits/windows + - `on_backend_unavailable` adapter event + +- **`AsyncRedisBackend`** (`shekel/backends/redis.py`) — async version of `RedisBackend` for FastAPI, async LangGraph, and other async contexts; same semantics, all public methods are coroutines + +- **`on_backend_unavailable` adapter event** (`shekel/integrations/base.py`) — fires before raising `BudgetExceededError` (fail-closed) or allowing through (fail-open); payload: `budget_name`, `error` + +### Fixed + +- **Nested budget node/chain cap enforcement** — node caps registered on an outer `budget()` context are now correctly enforced inside inner nested budget contexts; `_find_node_cap()` and `_find_chain_cap()` walk the parent chain to locate the cap + +### Technical + +- **245 new TDD tests**: 45 in `tests/test_runtime.py`, 41 in `tests/test_langchain_wrappers.py`, 36 in `tests/test_langgraph_wrappers.py`, 81 in `tests/test_distributed_budgets.py` (unit) + 419-line Docker integration suite in `tests/integrations/test_redis_docker.py` +- `shekel/__version__` bumped to `0.3.1` +- `Budget.chain`, `ChainBudgetExceededError`, `BudgetConfigMismatchError`, `RedisBackend`, `AsyncRedisBackend` all exported in `shekel.__all__` + ## [0.2.9] - 2026-03-15 ### Added diff --git a/README.md b/README.md index 394a948..88dc878 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,82 @@ print(pipeline.tree()) Children auto-cap to the parent's remaining balance. `workflow.tree()` gives you a visual breakdown. 
+### LangGraph — per-node circuit breaking + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # reads REDIS_URL from env + +with budget("$5/hr + 100 calls/hr", name="api", backend=backend) as b: + b.node("fetch_data", max_usd=0.50) # hard cap per LangGraph node + b.node("summarize", max_usd=1.00) + + app.invoke({"query": "..."}) # NodeBudgetExceededError if node cap hit + +print(b.tree()) +# api: $0.84 / $5.00 +# [node] fetch_data: $0.12 / $0.50 (24%) +# [node] summarize: $0.72 / $1.00 (72%) +``` + +Shekel patches `StateGraph.add_node()` transparently — no graph changes needed. Every node gets a pre-execution budget gate. Caps roll up to the parent budget. + +### LangChain — per-chain circuit breaking + +```python +with budget(max_usd=5.00, name="pipeline") as b: + b.chain("retriever", max_usd=0.20) + b.chain("summarizer", max_usd=1.00) + + retriever_chain.invoke({"query": "..."}) # ChainBudgetExceededError if cap hit + summarizer_chain.invoke({"doc": "..."}) +``` + +Shekel patches `Runnable._call_with_config` and `RunnableSequence.invoke` — zero changes to your chains. `b.chain()` is chainable and composable with `b.node()`, `b.agent()`, and `b.task()`. 
+ +### CrewAI — per-agent and per-task circuit breaking + +```python +from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError + +try: + with budget(max_usd=5.00, name="crew") as b: + b.agent(researcher.role, max_usd=2.00) # use agent.role directly + b.agent(writer.role, max_usd=1.00) + b.task(research_task.name, max_usd=1.50) # use task.name directly + b.task(write_task.name, max_usd=0.80) + crew.kickoff(inputs={"topic": "AI"}) # AgentBudgetExceededError or TaskBudgetExceededError if cap hit +except TaskBudgetExceededError as e: + print(f"Task '{e.task_name}' over budget") +except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' over budget") + +print(b.tree()) +# crew: $2.84 / $5.00 +# [agent] Senior Researcher: $1.92 / $2.00 (96.0%) +# [agent] Content Writer: $0.92 / $1.00 (92.0%) +# [task] research: $1.92 / $1.50 (128.0%) +# [task] write: $0.92 / $0.80 (115.0%) +``` + +Shekel patches `Agent.execute_task` transparently — zero crew or agent changes. Gate order: task cap → agent cap → global budget (most specific first). + +### Distributed budgets — enforce across multiple processes + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # REDIS_URL from env; fail-closed by default + +with budget("$5/hr + 100 calls/hr", name="api-tier", backend=backend) as b: + response = client.chat.completions.create(...) +# Atomic Lua-script enforcement — one Redis round-trip per call +# BudgetConfigMismatchError if the same name is reused with different limits +``` + +Works with `AsyncRedisBackend` for async workflows. Circuit breaker built in: configurable error threshold + cooldown before opening. Fail-open or fail-closed. + ### Rolling-window rate limits — `$5/hr` ```python @@ -175,6 +251,8 @@ async with api_budget: # BudgetExceededError carries retry_after and window_spent ``` +Multi-cap: `budget("$5/hr + 100 calls/hr")` — USD and call-count windows are independent. 
+ ### Accumulate across sessions ```python @@ -278,13 +356,32 @@ budget( tool_prices={"web_search": 0.01}, # charge per tool fallback={"at_pct": 0.8, "model": "gpt-4o-mini"}, # switch instead of crash name="my-agent", # required for nesting + temporal budgets + backend=RedisBackend(), # distributed enforcement across processes ) -budget("$5/hr", name="api-tier") # temporal: rolling-window rate limit +budget("$5/hr + 100 calls/hr", name="api-tier") # multi-cap rolling-window +``` + +**Component caps** (all chainable): + +```python +b.node("fetch_data", max_usd=0.50) # LangGraph node cap +b.chain("retriever", max_usd=0.20) # LangChain chain cap +b.agent("researcher", max_usd=1.00) # CrewAI agent cap — AgentBudgetExceededError +b.task("summarize", max_usd=0.50) # CrewAI task cap — TaskBudgetExceededError ``` -`BudgetExceededError` → `spent`, `limit`, `model`, `retry_after` (temporal) -`ToolBudgetExceededError` → `tool_name`, `calls_used`, `calls_limit`, `framework` +**Exceptions:** + +| Exception | Fields | +|---|---| +| `BudgetExceededError` | `spent`, `limit`, `model`, `retry_after` | +| `NodeBudgetExceededError` | `node_name`, `spent`, `limit` | +| `AgentBudgetExceededError` | `agent_name`, `spent`, `limit` | +| `TaskBudgetExceededError` | `task_name`, `spent`, `limit` | +| `ChainBudgetExceededError` | `chain_name`, `spent`, `limit` | +| `ToolBudgetExceededError` | `tool_name`, `calls_used`, `calls_limit`, `framework` | +| `BudgetConfigMismatchError` | raised by Redis backend on config conflict | --- diff --git a/ai-metadata.json b/ai-metadata.json index 90ddf1e..e4ae664 100644 --- a/ai-metadata.json +++ b/ai-metadata.json @@ -22,10 +22,14 @@ "usage-limits", "quotas", "ai-governance", - "budget-enforcement" + "budget-enforcement", + "redis", + "distributed-budgets", + "circuit-breaker", + "rolling-window" ], - "purpose": "LLM budget control and cost governance for AI agents", - "description": "Open-source Python library for LLM budget control, token budgeting, and 
AI agent cost governance for OpenAI, Anthropic, LangChain, LangGraph, and modern LLMOps systems.", + "purpose": "LLM budget control and cost governance for AI agents — including distributed, per-node, and per-chain enforcement", + "description": "Open-source Python library for LLM budget control, token budgeting, and AI agent cost governance for OpenAI, Anthropic, LangChain, LangGraph, and modern LLMOps systems. Supports distributed Redis-backed enforcement, per-node LangGraph circuit breaking, per-chain LangChain enforcement, and multi-cap rolling-window rate limits.", "repository": "https://github.com/arieradle/shekel", "documentation": "https://arieradle.github.io/shekel/", "pypi": "https://pypi.org/project/shekel/", @@ -40,7 +44,11 @@ "AI API spend guardrails", "Multi-stage workflow budgeting", "Hierarchical cost tracking", - "Automatic model fallback" + "Automatic model fallback", + "Distributed multi-process budget enforcement", + "Per-node LangGraph circuit breaking", + "Per-chain LangChain circuit breaking", + "Rolling-window rate limiting per API tier or user" ], "integrations": [ "OpenAI", @@ -51,7 +59,8 @@ "CrewAI", "AutoGen", "LlamaIndex", - "Haystack" + "Haystack", + "Redis" ], "features": [ "Budget enforcement with hard caps", @@ -61,6 +70,11 @@ "Token budgeting", "Async and streaming support", "Framework agnostic", - "Zero-config setup" + "Zero-config setup", + "LangGraph node-level circuit breaking — per-node USD caps with automatic StateGraph patching", + "LangChain chain-level circuit breaking — per-chain USD caps with automatic Runnable patching", + "Distributed Redis enforcement — atomic Lua-script check-and-add across multiple processes", + "Multi-cap rolling-window specs — simultaneous USD + call-count caps (e.g. 
$5/hr + 100 calls/hr)", + "Circuit breaker pattern — configurable error threshold and cooldown on backend unavailability" ] } diff --git a/design-deleteme/distributed-budgets.md b/design-deleteme/distributed-budgets.md new file mode 100644 index 0000000..48aa1dc --- /dev/null +++ b/design-deleteme/distributed-budgets.md @@ -0,0 +1,293 @@ +# Design Decision: Distributed Budgets + +**Date:** 2026-03-17 +**Status:** Draft +**Branch:** feat/distributed-budgets + +--- + +## Context + +Shekel currently enforces budgets in-process via `InMemoryBackend`. This does not survive process restarts and cannot be shared across pods or services. Users need budgets that: + +- Persist across process restarts +- Are enforced consistently across multiple pods/services +- Have their own lifecycle independent of any single process + +--- + +## Decision + +Implement a **Redis-backed `TemporalBudgetBackend`** as the first distributed backend. The existing `TemporalBudgetBackend` protocol is upgraded to be **generic and counter-based** so future backends (gossip, Postgres, etc.) can plug in without protocol changes. + +--- + +## API + +### Two forms, never mixed. Mixing raises `ValueError`. + +```python +# SPEC STRING — richer form; each cap fully self-described with its own window +budget("$5/hr", name="api") +budget("100 calls/hr", name="api") +budget("$5/hr + 100 calls/day", name="api") # per-cap windows, valid +budget("$5/hr + 100 calls/hr + 20 tools/hr", name="api") + +# KWARGS — convenience shorthand; one window_seconds applies to all caps +budget(name="api", max_usd=5.0, window_seconds=3600) +budget(name="api", max_usd=5.0, max_llm_calls=100, window_seconds=3600) +# Need different windows per cap? 
→ use spec string + +# MIXING → ValueError immediately +budget("$5/hr", name="api", max_llm_calls=100) # ← raises ValueError +``` + +| Need | Form | +|------|------| +| One cap | either form | +| Multi-cap, same window | either form | +| Multi-cap, different windows | spec string only | + +### Spec string identifiers + +| Spec string token | kwargs equivalent | Meaning | +|-------------------|-------------------|---------| +| `$N` or `N usd` | `max_usd` | USD spend | +| `N calls` | `max_llm_calls` | LLM calls | +| `N tools` | `max_tool_calls` | Tool calls | +| `N tokens` | `max_tokens` | Tokens (future) | + +Window units: `s`, `sec`, `min`, `hr`, `h`. (`day`/`week`/`month` out of scope v1.) + +--- + +## Backend Protocol (updated) + +Generic named-counter protocol. Backend is unaware of "USD" or "calls" — it manages named counters with limits. + +```python +class TemporalBudgetBackend(Protocol): + def check_and_add( + self, + budget_name: str, + amounts: dict[str, float], # {"usd": 0.03, "llm_calls": 1} + limits: dict[str, float | None], # {"usd": 5.0, "llm_calls": 100, "tool_calls": None} + windows: dict[str, float], # {"usd": 3600, "llm_calls": 86400} + ) -> tuple[bool, str | None]: # (allowed, first_exceeded_counter or None) + ... + + def get_state(self, budget_name: str) -> dict[str, float]: + # {"usd": 2.34, "llm_calls": 45} + ... + + def reset(self, budget_name: str) -> None: ... +``` + +- `None` limit = no cap; counter is still tracked for observability. +- Returns `(False, "usd")` or `(False, "llm_calls")` so caller raises the right exception. +- **Atomic all-or-nothing:** if any counter would exceed, none are incremented. +- **Future caps:** just new keys in `amounts`/`limits`/`windows` dicts — zero protocol changes. 
+ +--- + +## Redis Backend + +### Connection + +```python +from shekel.backends.redis import RedisBackend, AsyncRedisBackend + +# Auto-discover from env (REDIS_URL): +backend = RedisBackend() + +# Explicit: +backend = RedisBackend(url="redis://user:pass@host:6379/0", tls=True) + +# Async contexts (FastAPI, LangGraph): +backend = AsyncRedisBackend() + +# Use: +with budget("$5/hr", name="api", backend=backend): + run_agent() +``` + +- **Lazy connect** on first `check_and_add` call. +- **Connection pool** — not one connection per call. +- `close()` / context manager for explicit lifecycle; also safe without — pool auto-closes on GC. + +### Key layout + +``` +shekel:tb:{name} + spec_hash → "abc123" (hash of {counter: (limit, window_s)}; mismatch detection) + usd:max → "5.0" + usd:window_s → "3600" + usd:window_start → "1710000000000" (Redis TIME in ms) + usd:spent → "2.34" + calls:max → "100" + calls:window_s → "86400" + calls:window_start → "1710000000000" + calls:spent → "45" + tools:max → "" (empty = no cap, still tracked) + tools:window_s → "3600" + tools:window_start → "1710000000000" + tools:spent → "3" +``` + +Per-cap `window_start` and `window_s` → counters reset independently. One Redis hash per budget name. + +### Lua semantics (per call) + +``` +1. For each counter in amounts: + a. Read counter:window_start, counter:window_s from hash + b. Call redis.call('TIME') → now_ms + c. If now_ms - window_start >= window_s * 1000: reset counter:spent = 0, counter:window_start = now_ms + d. If counter:max != "" AND spent + amount > max: return (0, counter_name) +2. All counters passed: HINCRBYFLOAT each counter:spent; PEXPIRE key (TTL = max window_s * 2000); return (1, nil) +``` + +One round-trip. One atomic operation. All-or-nothing. + +### Identity and mismatch detection + +- First writer stores spec hash (`{counter: (limit, window_s)}` dict hashed) alongside the counters. +- Every subsequent attaching process computes spec hash locally and compares. 
+- Mismatch → `BudgetConfigMismatchError("Budget 'api' already registered with different limits/windows")`. +- `reset()` → `DEL` key (clears spec hash and all counters; next writer registers fresh). + +### Failure behavior + +| Failure | Default behavior | Config | +|---------|-----------------|--------| +| Redis unreachable / timeout | **Fail-closed** → raise `BudgetExceededError("Backend unavailable — failing closed")` | `on_unavailable="open"` to allow through | +| Redis error (OOM, READONLY) | Same as unreachable | Same | +| 3 consecutive errors | Circuit breaker → stop calling Redis for 10s cooldown | Configurable | +| Process restart | Spend state in Redis → **no data loss** | — | +| Redis node restart | Needs AOF/replication (HA Redis or managed) | Deployment concern | + +Emits `on_backend_unavailable(budget_name, error)` observability event before raising or allowing. + +### Multi-region + +Use a globally replicated Redis-compatible store (e.g. **AWS MemoryDB multi-region**, ElastiCache Global Datastore). Shekel connects to a single endpoint (regional proxy to global primary). **Enforce on primary only** to avoid stale-replica over-spend. Cross-region write latency and RPO/RTO are deployment concerns, not Shekel concerns. + +### TTL and cleanup + +- Key TTL = `max(window_s) * 2000` ms (twice the longest counter window; the TTL is set on the budget's single hash key, not per counter) — renewed on every write via `PEXPIRE`. +- If budget is abandoned, key auto-expires. +- `reset()` → `DEL` key entirely. + +--- + +## Enforcement + +- **Post-flight:** `check_and_add` fires after LLM call completes with actual cost. One call may overshoot just before enforcement — same semantics as `InMemoryBackend`, documented. +- **Raise on first exceeded counter.** Deterministic check order: `usd → llm_calls → tool_calls → custom`. +- Error carries counter name: `BudgetExceededError("Budget 'api' exceeded: calls (100/hr)")`. 
+ +--- + +## Observability + +New event added to `ObservabilityAdapter` (default no-op): + +```python +def on_backend_unavailable(self, budget_name: str, error: Exception) -> None: ... +``` + +All existing events (`on_window_reset`, `on_budget_exceeded`, `on_cost_update`, etc.) fire as normal. + +--- + +## Testing with Redis + +### Recommended for GitHub CI: service container (free, no hosted API) + +**You do not need a paid or hosted Redis API for CI.** GitHub Actions can run a real Redis server next to the job using a **service container** — same Redis wire protocol as production, no external account. + +```yaml +# .github/workflows/… (excerpt) +jobs: + test: + runs-on: ubuntu-latest + services: + redis: + image: redis:7-alpine + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 5s + --health-timeout 3s + --health-retries 5 + env: + REDIS_URL: redis://127.0.0.1:6379/0 + steps: + - uses: actions/checkout@v4 + # … install Python, pytest … + - run: pytest tests/ -m redis # or run integration tests when REDIS_URL is set +``` + +- **Cost:** Included in GitHub Actions minutes for public repos; private repos use your Actions quota. No Redis vendor bill. +- **Pull speed:** `redis:7-alpine` is a small image (~tens of MB); first pull on a runner is often ~10–30s, with layer cache helping on subsequent jobs. Startup after pull is sub-second. If image pull is a concern, alternatives are pinning `redis:7-alpine` for cache stability or installing `redis-server` via `apt` on the runner (no Docker pull for Redis). +- **Isolation:** Fresh Redis per job run; tests can `FLUSHDB` or use key prefixes without polluting shared state. +- **Docs:** [Creating Redis service containers](https://docs.github.com/en/actions/guides/creating-redis-service-containers) (GitHub). + +**Local dev:** `docker run -p 6379:6379 redis:7-alpine` and set `REDIS_URL=redis://127.0.0.1:6379/0`. 
+ +### Unit tests without Docker: fakeredis / mocks + +- **fakeredis:** In-process fake Redis; Lua support varies by version — validate Lua scripts against a real Redis in CI if fakeredis cannot run your script. +- **Mocks:** Stub `TemporalBudgetBackend` for pure unit tests of `TemporalBudget` wiring. + +### Optional: hosted “free tier” Redis (REST or TCP) + +Useful for **manual smoke tests**, **fork PRs** (if you avoid secrets in fork workflows), or when you cannot use service containers — **not required** for normal CI. + +| Provider | Free tier (typical) | Protocol | CI notes | +|----------|---------------------|----------|----------| +| **Upstash** | Limited commands/month, small storage | **TCP Redis** (TLS URL) and **HTTP REST** | Standard `redis-py` works with their Redis URL + TLS. REST API is a different client — Shekel should target **TCP Redis** unless you add an HTTP adapter. | +| **Redis Cloud (Redis Inc.)** | Small free instance | TCP | Create DB, put URL in **GitHub Actions secret** `REDIS_URL`; optionally gate the test steps on it — `secrets` cannot be referenced directly in `if:` conditionals, so expose it as a job-level `env` variable and use `if: env.REDIS_URL != ''` on the steps. | +| **ElastiCache / MemoryDB** | No perpetual free tier | TCP | Pay-as-you-go; use for staging, not default CI. | + +**Caveats for hosted free tier in CI:** + +- **Rate / command limits** — parallel test jobs can exhaust free quotas. +- **Secrets** — fork PRs from untrusted contributors often **do not** receive repository secrets; service containers avoid that. +- **Shared DB** — multiple CI runs against one free DB can flake; prefer **ephemeral CI Redis** (service container) for PR pipelines. + +**Recommendation:** **Primary CI path = Redis service container + `REDIS_URL` to localhost.** Optional nightly or manual workflow with a secret `REDIS_URL` to a hosted free tier if you want to validate TLS / cloud-specific behavior.
+ +### Integration test layout (suggested) + +| Layer | Tool | When | +|-------|------|------| +| Protocol / Lua | Real Redis (Docker locally, service in CI) | Required before merge for Redis backend | +| Budget wiring | `InMemoryBackend` or fakeredis | Fast unit tests | +| Optional cloud smoke | Secret `REDIS_URL` to Upstash/Redis Cloud | Nightly or on-demand | + +--- + +## Out of Scope (v1) + +| Item | Notes | +|------|-------| +| Gossip / no-SPoF backend | Protocol is open; not implemented in v1 | +| K8s CRD for budget config | Possible future layer; not required for Redis v1 | +| Compound spec strings (`"$5/hr + 100 calls/day"` as compound) | Spec string supports per-cap windows via individual term parsing | +| `day` / `week` / `month` calendar windows | Not supported in v1 (existing limitation) | +| Per-cap window kwargs | Use spec string form instead | +| Multi-region split-quota | Handled by global Redis product choice | + +--- + +## Implementation Plan + +1. **Upgrade `TemporalBudgetBackend` protocol** — generic counters + per-cap windows. +2. **Upgrade `InMemoryBackend`** — match new protocol (also serves as test double). +3. **Upgrade `_parse_spec`** — support `"$5/hr + 100 calls/day"` multi-cap format. +4. **Upgrade `TemporalBudget`** — build `amounts`/`limits`/`windows` dicts from config; raise on form mixing. +5. **Implement `RedisBackend` + `AsyncRedisBackend`** — Lua script, lazy pool, spec hash, circuit breaker. +6. **Add `on_backend_unavailable` to `ObservabilityAdapter`** and emit in Redis backend. +7. **Tests (TDD):** unit (fakeredis / mocks), integration against real Redis (local Docker + GitHub Actions **Redis service container** — see **Testing with Redis**), fail-closed/open, mismatch detection, multi-cap enforcement, window reset. 
diff --git a/docs/api-reference.md b/docs/api-reference.md index 3ffb660..b67af6b 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -243,11 +243,207 @@ print(w.tree()) # analysis: $9.30 / $10.00 (direct: $9.30) ``` +Also renders registered component budgets (nodes, agents, tasks). LangGraph node spend is tracked automatically; agent/task spend requires future framework adapters. + +```python +with budget(max_usd=10, name="workflow") as b: + b.node("fetch", max_usd=0.50) + b.node("summarize", max_usd=1.00) + run_langgraph_workflow() + +print(b.tree()) +# workflow: $0.84 / $10.00 (direct: $0.00) +# [node] fetch: $0.12 / $0.50 (24.0%) +# [node] summarize: $0.72 / $1.00 (72.0%) +``` + **Returns:** Multi-line string with indented tree structure showing: - Budget name and hierarchy - Total spend / limit - Direct spend (excluding children) - `[ACTIVE]` marker for currently active children +- `[node]`, `[agent]`, `[task]` component budget lines with spend/limit/percentage + +#### `node(name, max_usd)` + +Register an explicit USD cap for a named LangGraph node. Returns `self` for chaining. + +The cap is enforced by `LangGraphAdapter` — `NodeBudgetExceededError` is raised before the node body executes when the cap is reached. Spend is attributed to `ComponentBudget._spent` and visible in `budget.tree()`. + +```python +with budget(max_usd=10.00) as b: + b.node("fetch_data", max_usd=0.50).node("summarize", max_usd=1.00) + + graph = StateGraph(State) + graph.add_node("fetch_data", fetch_fn) + graph.add_node("summarize", summarize_fn) + app = graph.compile() + app.invoke(state) +``` + +**Parameters:** +- `name` — node name (must match the name passed to `StateGraph.add_node()`) +- `max_usd` — USD cap; must be positive + +**Raises:** `ValueError` if `max_usd <= 0` + +#### `agent(name, max_usd)` + +Register an explicit USD cap for a named CrewAI agent. Returns `self` for chaining. 
+ +Enforced by `CrewAIExecutionAdapter` — `AgentBudgetExceededError` is raised **before** `Agent.execute_task` runs when the cap is exhausted. Use `agent.role` as the key to eliminate string mismatch risk. Spend is attributed to `ComponentBudget._spent` and visible in `budget.tree()`. + +```python +with budget(max_usd=10.00) as b: + b.agent(researcher.role, max_usd=2.00).agent(writer.role, max_usd=1.50) + crew.kickoff(inputs={"topic": "AI"}) +``` + +**Parameters:** +- `name` — agent name (use `agent.role` directly) +- `max_usd` — USD cap; must be positive + +**Raises:** `ValueError` if `max_usd <= 0` + +#### `task(name, max_usd)` + +Register an explicit USD cap for a named CrewAI task. Returns `self` for chaining. + +Enforced by `CrewAIExecutionAdapter` — `TaskBudgetExceededError` is raised **before** `Agent.execute_task` runs when the cap is exhausted. Use `task.name` as the key directly. Gate order: task cap is checked before agent cap (most specific first). Spend is attributed independently to both the task and the executing agent. + +```python +with budget(max_usd=10.00) as b: + b.task(research_task.name, max_usd=1.50).task(write_task.name, max_usd=0.80) + crew.kickoff(inputs={"topic": "AI"}) +``` + +**Parameters:** +- `name` — task name (use `task.name` directly) +- `max_usd` — USD cap; must be positive + +**Raises:** `ValueError` if `max_usd <= 0` + +#### `chain(name, max_usd)` + +Register an explicit USD cap for a named LangChain chain. Returns `self` for chaining. + +Enforced by `LangChainRunnerAdapter` — `ChainBudgetExceededError` is raised before the chain body executes when the cap is reached. Spend is attributed to `ComponentBudget._spent` and visible in `budget.tree()`. 
+ +```python +with budget(max_usd=10.00) as b: + b.chain("retriever", max_usd=0.20).chain("summarizer", max_usd=1.00) + chain.invoke({"query": "..."}) +``` + +**Parameters:** +- `name` — chain name (must match the `run_name` or object name passed to `add_node`/invoked directly) +- `max_usd` — USD cap; must be positive + +**Raises:** `ValueError` if `max_usd <= 0` + +--- + +## `TemporalBudget` (rolling-window budgets) + +Created via the `budget()` factory when a spec string or `window_seconds` is provided. + +### Temporal factory forms + +```python +# Spec string (per-cap windows) +with budget("$5/hr", name="api") as b: ... +with budget("$5/hr + 100 calls/hr", name="api") as b: ... + +# Kwargs (single shared window) +with budget(max_usd=5.0, window_seconds=3600, name="api") as b: ... +with budget(max_usd=5.0, max_llm_calls=100, window_seconds=3600, name="api") as b: ... +``` + +`name=` is required for `TemporalBudget`. + +### Supported counters in multi-cap specs + +| Token | Counter | Example | +|---|---|---| +| `$N` or `N usd` | `usd` | `$5/hr` | +| `N calls` | `llm_calls` | `100 calls/hr` | +| `N tools` | `tool_calls` | `20 tools/hr` | +| `N tokens` | `tokens` | `50000 tokens/hr` | + +### Using a custom backend + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend(url="redis://localhost:6379/0") + +with budget("$5/hr", name="api", backend=backend) as b: + run_agent() +``` + +--- + +## `RedisBackend` + +Synchronous Redis-backed rolling-window budget backend for distributed enforcement. 
+ +### Constructor + +```python +RedisBackend( + url: str | None = None, # defaults to REDIS_URL env var + tls: bool = False, + on_unavailable: str = "closed", # "closed" | "open" + circuit_breaker_threshold: int = 3, + circuit_breaker_cooldown: float = 10.0, +) +``` + +| Parameter | Default | Description | +|---|---|---| +| `url` | `REDIS_URL` env | Redis connection URL | +| `tls` | `False` | Force TLS (`ssl=True`) | +| `on_unavailable` | `"closed"` | `"closed"` raises `BudgetExceededError`; `"open"` allows through | +| `circuit_breaker_threshold` | `3` | Consecutive errors before circuit opens | +| `circuit_breaker_cooldown` | `10.0` | Seconds before retrying after circuit opens | + +### Example + +```python +from shekel import budget +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # reads REDIS_URL from env + +with budget("$5/hr + 100 calls/hr", name="api-tier", backend=backend) as b: + run_agent() +``` + +### Methods + +- `check_and_add(budget_name, amounts, limits, windows)` — atomically check + increment counters +- `get_state(budget_name)` — return `{counter: spent}` for all counters +- `reset(budget_name)` — delete the Redis hash for `budget_name` +- `close()` — close the Redis connection + +**Raises:** `BudgetConfigMismatchError` if `budget_name` is already registered with different limits or windows. + +--- + +## `AsyncRedisBackend` + +Async version of `RedisBackend`. All public methods are coroutines. Suitable for FastAPI, async LangGraph, and other async contexts. + +```python +from shekel.backends.redis import AsyncRedisBackend + +backend = AsyncRedisBackend() + +async with budget("$5/hr", name="api", backend=backend) as b: + await run_async_agent() +``` + +Constructor and parameters are identical to `RedisBackend`. --- @@ -343,6 +539,122 @@ except BudgetExceededError as e: --- +## `NodeBudgetExceededError` + +Raised when a LangGraph node exceeds its registered USD cap. Subclass of `BudgetExceededError`. 
+ +### Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `node_name` | `str` | Name of the node that exceeded its budget. | +| `spent` | `float` | Total USD spent when the cap was hit. | +| `limit` | `float` | The configured `max_usd` for this node. | + +```python +from shekel import budget, NodeBudgetExceededError, BudgetExceededError + +try: + with budget(max_usd=10.00) as b: + b.node("fetch", max_usd=0.10) + run_fetch_node() +except NodeBudgetExceededError as e: + print(f"Node '{e.node_name}' exceeded ${e.limit:.2f}") +except BudgetExceededError: + # catches all budget errors including NodeBudgetExceededError + ... +``` + +--- + +## `AgentBudgetExceededError` + +Raised when an agent exceeds its registered USD cap. Subclass of `BudgetExceededError`. + +### Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `agent_name` | `str` | Name of the agent that exceeded its budget. | +| `spent` | `float` | Total USD spent when the cap was hit. | +| `limit` | `float` | The configured `max_usd` for this agent. | + +--- + +## `TaskBudgetExceededError` + +Raised when a task exceeds its registered USD cap. Subclass of `BudgetExceededError`. + +### Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `task_name` | `str` | Name of the task that exceeded its budget. | +| `spent` | `float` | Total USD spent when the cap was hit. | +| `limit` | `float` | The configured `max_usd` for this task. | + +--- + +## `SessionBudgetExceededError` + +Raised when an always-on agent session exceeds its rolling-window budget. Subclass of `BudgetExceededError`. + +### Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `agent_name` | `str` | Name of the agent session that exceeded its budget. | +| `spent` | `float` | Total USD spent when the cap was hit. | +| `limit` | `float` | The configured session budget. 
| +| `window` | `float \| None` | Rolling window duration in seconds, or `None`. | + +--- + +## `ChainBudgetExceededError` + +Raised when a LangChain chain exceeds its registered USD cap. Subclass of `BudgetExceededError`. + +### Attributes + +| Attribute | Type | Description | +|-----------|------|-------------| +| `chain_name` | `str` | Name of the chain that exceeded its budget. | +| `spent` | `float` | Total USD spent when the cap was hit. | +| `limit` | `float` | The configured `max_usd` for this chain. | + +```python +from shekel import budget, ChainBudgetExceededError, BudgetExceededError + +try: + with budget(max_usd=10.00) as b: + b.chain("retriever", max_usd=0.20) + chain.invoke({"query": "..."}) +except ChainBudgetExceededError as e: + print(f"Chain '{e.chain_name}' exceeded ${e.limit:.2f}") +``` + +--- + +## `BudgetConfigMismatchError` + +Raised by `RedisBackend` / `AsyncRedisBackend` when a budget name is already registered with different limits or windows. Subclass of `BudgetExceededError`. + +```python +from shekel.exceptions import BudgetConfigMismatchError + +try: + with budget("$5/hr", name="api", backend=backend): + run_agent() +except BudgetConfigMismatchError: + # Budget "api" was previously registered with different caps. + # Call backend.reset("api") to clear existing state. + backend.reset("api") +``` + +**To resolve:** call `backend.reset(budget_name)` to delete the existing Redis state, then retry. + +--- + ## Type Signatures For type checking with mypy, pyright, etc: diff --git a/docs/changelog.md b/docs/changelog.md index 0591c68..7ea87f5 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,73 @@ All notable changes to this project are documented here. For detailed information, see [CHANGELOG.md](https://github.com/arieradle/shekel/blob/main/CHANGELOG.md) on GitHub. 
+## [0.3.1] {#031} + +### Hierarchical Budget Enforcement — LangGraph + LangChain circuit breaking + Distributed Budgets + +Per-node, per-chain, per-agent, and per-task USD caps with automatic LangGraph and LangChain instrumentation. Distributed enforcement via Redis for multi-process deployments. Zero code changes required — open a `budget()` context and run. + +```python +from shekel.backends.redis import RedisBackend + +backend = RedisBackend() # reads REDIS_URL from env + +with budget("$5/hr + 100 calls/hr", name="api", backend=backend) as b: + b.node("fetch_data", max_usd=0.50) + b.node("summarize", max_usd=1.00) + b.chain("retriever", max_usd=0.20) + + app.invoke({"query": "..."}) + +print(b.tree()) +# api: $0.84 / $5.00 (direct: $0.00) +# [node] fetch_data: $0.12 / $0.50 (24.0%) +# [node] summarize: $0.72 / $1.00 (72.0%) +# [chain] retriever: $0.00 / $0.20 (0.0%) +``` + +**LangGraph adapter** (`shekel/providers/langgraph.py`): +- Patches `StateGraph.add_node()` transparently — every node gets a pre-execution budget gate, no graph changes needed +- `NodeBudgetExceededError` raised *before* the node body runs when an explicit cap or the parent budget is exhausted +- Per-node spend attributed to `ComponentBudget._spent` → visible in `budget.tree()` +- Full async node support; auto-skipped when `langgraph` is not installed + +**LangChain adapter** (`shekel/providers/langchain.py`): +- Patches `Runnable._call_with_config`, `_acall_with_config`, and `RunnableSequence.invoke`/`ainvoke` +- `ChainBudgetExceededError` raised before chain body runs when cap or parent budget is exhausted +- Same reference-counting and nesting semantics as `LangGraphAdapter` +- Auto-skipped when `langchain_core` is not installed + +**Distributed budgets** (`shekel/backends/redis.py`): +- `RedisBackend` / `AsyncRedisBackend` — atomic Lua-script enforcement (one round-trip) +- Circuit breaker: configurable error threshold + cooldown before opening +- Fail-closed (default) or fail-open on 
backend unavailability +- `BudgetConfigMismatchError` when a budget name is reused with different limits/windows + +**Multi-cap temporal spec**: +- `budget("$5/hr + 100 calls/hr", name="api")` — simultaneous USD + call-count caps with independent rolling windows +- Supported counters: `usd`, `llm_calls`, `tool_calls`, `tokens` + +**API** (`Budget` methods, all chainable): +- `b.node(name, max_usd)` — explicit cap for a LangGraph node +- `b.chain(name, max_usd)` — explicit cap for a LangChain chain +- `b.agent(name, max_usd)` — explicit cap for a CrewAI / OpenClaw agent *(enforcement in future release)* +- `b.task(name, max_usd)` — explicit cap for a CrewAI task *(enforcement in future release)* + +**Exception hierarchy** (all subclass `BudgetExceededError`): +- `NodeBudgetExceededError` — `node_name`, `spent`, `limit` +- `ChainBudgetExceededError` — `chain_name`, `spent`, `limit` +- `AgentBudgetExceededError` — `agent_name`, `spent`, `limit` +- `TaskBudgetExceededError` — `task_name`, `spent`, `limit` +- `SessionBudgetExceededError` — `agent_name`, `spent`, `limit`, `window` +- `BudgetConfigMismatchError` — raised by Redis backend on config conflict + +**Fixed:** Node and chain caps registered on an outer `budget()` are now correctly enforced inside inner nested budget contexts. + +[Full CHANGELOG →](https://github.com/arieradle/shekel/blob/main/CHANGELOG.md#031) + +--- + ## [0.2.9] {#029} ### 🖥️ CLI Budget Enforcement — `shekel run` diff --git a/docs/index.md b/docs/index.md index a3f9856..40845c8 100644 --- a/docs/index.md +++ b/docs/index.md @@ -262,7 +262,70 @@ print(f"Remaining: ${b.remaining:.4f}") --- -## What's New in v0.2.9 +## What's New in v0.3.1 + +
+ +- :material-graph:{ .lg .middle } **LangGraph Node-Level Circuit Breaking** + + --- + + Per-node USD caps enforced automatically. `NodeBudgetExceededError` raised before the node body runs. Zero graph changes required. + + ```python + with budget(max_usd=10.00) as b: + b.node("fetch", max_usd=0.50) + b.node("summarize", max_usd=1.00) + + graph = StateGraph(State) + graph.add_node("fetch", fetch_fn) + graph.add_node("summarize", summarize_fn) + app = graph.compile() + app.invoke(state) + + print(b.tree()) + # [node] fetch: $0.12 / $0.50 (24.0%) + # [node] summarize: $0.72 / $1.00 (72.0%) + ``` + +- :material-robot:{ .lg .middle } **CrewAI Agent/Task Circuit Breaking** + + --- + + Per-agent and per-task USD caps enforced automatically. `AgentBudgetExceededError` / `TaskBudgetExceededError` raised before the agent executes. Zero crew changes required. + + ```python + with budget(max_usd=5.00) as b: + b.agent(researcher.role, max_usd=2.00) + b.task(research_task.name, max_usd=1.50) + crew.kickoff(inputs={"topic": "AI"}) + + print(b.tree()) + # [agent] Senior Researcher: $1.92 / $2.00 (96.0%) + # [task] research: $1.92 / $1.50 (128.0%) + ``` + +- :material-alert-decagram:{ .lg .middle } **Level-Specific Exceptions** + + --- + + `NodeBudgetExceededError`, `AgentBudgetExceededError`, `TaskBudgetExceededError`, `SessionBudgetExceededError` — all subclass `BudgetExceededError` so existing `except` blocks catch everything. + + ```python + except TaskBudgetExceededError as e: + print(f"Task '{e.task_name}' over budget") + except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' over budget") + except BudgetExceededError: + # catches all budget errors + ... + ``` + +
+ +--- + +## Previous: v0.2.9
diff --git a/docs/integrations/crewai.md b/docs/integrations/crewai.md index 840ed80..5d22dc9 100644 --- a/docs/integrations/crewai.md +++ b/docs/integrations/crewai.md @@ -1,6 +1,6 @@ # CrewAI Integration -Shekel integrates seamlessly with [CrewAI](https://github.com/joaomdmoura/crewAI) to track and enforce budgets on multi-agent workflows. +Shekel integrates with [CrewAI](https://github.com/joaomdmoura/crewAI) to enforce per-agent, per-task, and global spend limits on multi-agent workflows — with zero changes to your crew definition. ## Installation @@ -8,154 +8,130 @@ Shekel integrates seamlessly with [CrewAI](https://github.com/joaomdmoura/crewAI pip install shekel[openai] crewai ``` -## Basic Integration +## Zero-config global cap -Wrap your Crew execution with a budget context: +Wrap any crew execution with `budget()` — shekel auto-detects CrewAI and enforces the cap: ```python from crewai import Agent, Task, Crew -from shekel import budget - -# Define your agent -researcher = Agent( - role='Researcher', - goal='Research and provide accurate information', - backstory='Expert researcher with attention to detail', - verbose=True -) - -# Define task -task = Task( - description='Research the history of artificial intelligence', - agent=researcher, - expected_output='A comprehensive overview of AI history' -) - -# Create crew -crew = Crew( - agents=[researcher], - tasks=[task], - verbose=True -) - -# Execute with budget -with budget(max_usd=1.00) as b: - result = crew.kickoff() - print(f"Crew execution cost: ${b.spent:.4f}") +from shekel import budget, BudgetExceededError + +researcher = Agent(role="Senior Researcher", goal="Find key facts", backstory="...", llm="gpt-4o-mini") +writer = Agent(role="Content Writer", goal="Write a summary", backstory="...", llm="gpt-4o-mini") + +research_task = Task(name="research", description="Research: {topic}", expected_output="3 facts", agent=researcher) +write_task = Task(name="write", description="Write a paragraph summary", 
expected_output="1 paragraph", agent=writer) + +crew = Crew(agents=[researcher, writer], tasks=[research_task, write_task]) + +try: + with budget(max_usd=5.00) as b: + crew.kickoff(inputs={"topic": "climate change"}) + print(f"Done. Spent: ${b.spent:.4f}") +except BudgetExceededError as e: + print(f"Budget exceeded: {e}") ``` -## Multi-Agent Crews +Shekel patches `Agent.execute_task` transparently at `budget().__enter__()` and restores it at `__exit__()`. No crew or agent changes are needed. -Track costs across multiple agents: +## Per-agent caps + +Register caps keyed by `agent.role` — using the attribute directly eliminates string mismatch risk: ```python -from crewai import Agent, Task, Crew, Process -from shekel import budget - -# Define agents -researcher = Agent( - role='Researcher', - goal='Gather comprehensive information', - backstory='Experienced researcher', -) - -writer = Agent( - role='Writer', - goal='Create engaging content', - backstory='Professional content writer', -) - -editor = Agent( - role='Editor', - goal='Ensure quality and accuracy', - backstory='Meticulous editor', -) - -# Define tasks -research_task = Task( - description='Research AI developments in 2024', - agent=researcher, - expected_output='Research findings' -) - -write_task = Task( - description='Write an article based on research', - agent=writer, - expected_output='Draft article', -) - -edit_task = Task( - description='Edit and finalize the article', - agent=editor, - expected_output='Final article' -) - -# Create crew -crew = Crew( - agents=[researcher, writer, editor], - tasks=[research_task, write_task, edit_task], - process=Process.sequential, - verbose=True -) - -# All agents tracked under one budget -with budget(max_usd=5.00) as b: - result = crew.kickoff() - print(f"Total crew cost: ${b.spent:.4f}") - print(b.summary()) +from shekel.exceptions import AgentBudgetExceededError + +try: + with budget(max_usd=5.00, name="agents") as b: + b.agent(researcher.role, max_usd=2.00) + 
b.agent(writer.role, max_usd=1.00) + crew.kickoff(inputs={"topic": "quantum computing"}) + print(f"Done. Spent: ${b.spent:.4f}") +except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' exceeded cap: ${e.spent:.4f} / ${e.limit:.2f}") +except BudgetExceededError as e: + print(f"Global budget exceeded: {e}") + +print(b.tree()) +# agents: $2.84 / $5.00 (direct: $0.00) +# [agent] Senior Researcher: $1.92 / $2.00 (96.0%) +# [agent] Content Writer: $0.92 / $1.00 (92.0%) ``` -## Budget Protection for Crews +`AgentBudgetExceededError` is raised **before** the agent executes — the agent body never runs when the cap is already exhausted. It subclasses `BudgetExceededError`, so existing `except BudgetExceededError` blocks catch it automatically. + +## Per-agent + per-task caps -Prevent runaway costs from agent loops: +Combine agent and task caps. Use `task.name` as the key directly: ```python -from shekel import budget, BudgetExceededError +from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError try: - with budget(max_usd=2.00, warn_at=0.8) as b: - result = crew.kickoff() - print(f"Success! Cost: ${b.spent:.4f}") + with budget(max_usd=5.00, name="full") as b: + b.agent(researcher.role, max_usd=2.00) + b.agent(writer.role, max_usd=1.00) + b.task(research_task.name, max_usd=1.50) + b.task(write_task.name, max_usd=0.80) + crew.kickoff(inputs={"topic": "renewable energy"}) + print(f"Done. Spent: ${b.spent:.4f}") +except TaskBudgetExceededError as e: + print(f"Task '{e.task_name}' exceeded cap: ${e.spent:.4f} / ${e.limit:.2f}") +except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' exceeded cap") except BudgetExceededError as e: - print(f"Crew stopped due to budget: ${e.spent:.4f}") + print(f"Global budget exceeded: {e}") ``` -## Fallback Models +Gate order (most specific first): **task cap → agent cap → global budget**. 
Spend delta from each execution is attributed independently to both the agent and the task — they represent different aggregation views of the same spend. -Use cheaper models when budget is reached: +## Spend breakdown with `b.tree()` ```python -with budget(max_usd=1.00, fallback={"at_pct": 0.8, "model": "gpt-4o-mini"}) as b: - result = crew.kickoff() - - if b.model_switched: - print(f"Switched to cheaper model at ${b.switched_at_usd:.4f}") +print(b.tree()) +# full: $3.42 / $5.00 (direct: $0.00) +# [agent] Senior Researcher: $2.10 / $2.00 (105.0%) +# [agent] Content Writer: $1.32 / $1.00 (132.0%) +# [task] research: $2.10 / $1.50 (140.0%) +# [task] write: $1.32 / $0.80 (165.0%) ``` -## Per-Crew Budgets +`b.tree()` shows live spend for every registered component alongside its cap and utilization percentage. + +## Nested budgets -Different budgets for different crew types: +Per-agent and per-task caps are inherited through the parent chain: ```python -CREW_BUDGETS = { - "research": 0.50, - "writing": 1.00, - "analysis": 2.00, -} - -def run_crew(crew_type: str, crew: Crew): - budget_limit = CREW_BUDGETS.get(crew_type, 1.00) - - with budget(max_usd=budget_limit) as b: - result = crew.kickoff() - return result, b.spent - -result, cost = run_crew("research", research_crew) -print(f"{crew_type} crew cost: ${cost:.4f}") +with budget(max_usd=10.00, name="outer") as outer: + with budget(max_usd=5.00, name="inner") as inner: + inner.agent(researcher.role, max_usd=2.00) + crew.kickoff(inputs={"topic": "AI"}) + # inner spend rolls up to outer automatically +``` + +## Warnings for unnamed tasks + +If a `Task` has no `name` attribute and task caps are registered, shekel emits a `UserWarning`: + +``` +shekel: task has no name (description: 'Research the topic...') — set task.name to apply caps. ``` +Always set `task.name` when using `b.task()` caps to avoid silent misses. 
+ +## Exception hierarchy + +| Exception | Raised when | Fields | +|---|---|---| +| `AgentBudgetExceededError` | Agent cap exhausted | `agent_name`, `spent`, `limit` | +| `TaskBudgetExceededError` | Task cap exhausted | `task_name`, `spent`, `limit` | +| `BudgetExceededError` | Global budget exhausted | `spent`, `limit`, `model` | + +`AgentBudgetExceededError` and `TaskBudgetExceededError` both subclass `BudgetExceededError`, so a single `except BudgetExceededError` catches all three. + ## Next Steps -- [OpenAI Integration](openai.md) -- [LangGraph Integration](langgraph.md) +- [LangGraph Integration](langgraph.md) — per-node circuit breaking - [Budget Enforcement](../usage/budget-enforcement.md) +- [API Reference](../api-reference.md) diff --git a/docs/quickstart.md b/docs/quickstart.md index de9b38b..ea30316 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -282,18 +282,36 @@ print(f"Graph execution cost: ${b.spent:.4f}") ### CrewAI +Shekel auto-detects CrewAI and enforces caps with zero crew changes. Use `b.agent()` and `b.task()` for per-component circuit breaking: + ```python from crewai import Agent, Task, Crew from shekel import budget +from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError + +researcher = Agent(role="Researcher", goal="...", backstory="...", llm="gpt-4o-mini") +writer = Agent(role="Writer", goal="...", backstory="...", llm="gpt-4o-mini") + +research_task = Task(name="research", description="Research: {topic}", expected_output="...", agent=researcher) +write_task = Task(name="write", description="Write a summary.", expected_output="...", agent=writer) -# Your agents and tasks here -crew = Crew(agents=[agent], tasks=[task]) +crew = Crew(agents=[researcher, writer], tasks=[research_task, write_task]) -with budget(max_usd=2.00) as b: - result = crew.kickoff() -print(f"Crew execution cost: ${b.spent:.4f}") +try: + with budget(max_usd=5.00) as b: + b.agent(researcher.role, max_usd=2.00) # AgentBudgetExceededError if exceeded + b.task(research_task.name, 
max_usd=1.50) # TaskBudgetExceededError if exceeded + crew.kickoff(inputs={"topic": "AI"}) + print(f"Done. Spent: ${b.spent:.4f}") + print(b.tree()) +except TaskBudgetExceededError as e: + print(f"Task '{e.task_name}' exceeded cap") +except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' exceeded cap") ``` +See [CrewAI Integration](integrations/crewai.md) for the full reference. + ## Viewing Spend Summary Get a detailed breakdown of your spending: diff --git a/examples/crewai_demo.py b/examples/crewai_demo.py index 067996a..ca68e7f 100644 --- a/examples/crewai_demo.py +++ b/examples/crewai_demo.py @@ -1,17 +1,25 @@ # Requires: pip install shekel[openai] crewai """ -CrewAI example: budget enforcement across a multi-agent crew. +CrewAI demo: per-agent and per-task circuit breaking with shekel. + +Shekel patches Agent.execute_task transparently — every agent execution +gets a pre-execution budget gate with no crew or agent changes needed. + +Shows four patterns: +1. Global cap only — zero config, shekel auto-detects CrewAI +2. Per-agent caps — b.agent(agent.role, max_usd=X) +3. Per-agent + per-task caps — combined enforcement +4. 
b.tree() — full spend breakdown after execution """ import os -from shekel import BudgetExceededError, budget - def main() -> None: try: from crewai import Agent, Crew, Task # type: ignore[import] - except ImportError: + except ImportError as e: + print(f"Missing dependency: {e}") print("Run: pip install shekel[openai] crewai") return @@ -19,52 +27,102 @@ def main() -> None: print("Set OPENAI_API_KEY to run this demo.") return + from shekel import BudgetExceededError, budget + from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError + researcher = Agent( - role="Research Analyst", - goal="Find key facts about the given topic", - backstory="You are an expert researcher who finds concise, accurate information.", + role="Senior Researcher", + goal="Find key facts about the topic", + backstory="Expert researcher with broad knowledge.", llm="gpt-4o-mini", verbose=False, ) - writer = Agent( role="Content Writer", - goal="Write a short summary based on research findings", - backstory="You write clear, engaging summaries.", + goal="Summarize research into a clear paragraph", + backstory="Skilled writer who distills complex ideas.", llm="gpt-4o-mini", verbose=False, ) research_task = Task( - description="Research 3 key facts about Python programming language.", - expected_output="A bullet list of 3 facts.", + name="research", + description="Research the topic: {topic}", + expected_output="A bullet list of 3 key facts.", agent=researcher, ) - write_task = Task( - description="Write a 2-sentence summary based on the research.", - expected_output="A 2-sentence summary.", + name="write", + description="Write a one-paragraph summary of the research.", + expected_output="A single paragraph.", agent=writer, ) crew = Crew(agents=[researcher, writer], tasks=[research_task, write_task], verbose=False) # ------------------------------------------------------------------ - # Run crew with budget cap + # 1. 
Global cap only — zero config, shekel auto-detects CrewAI + # ------------------------------------------------------------------ + print("=== Global cap only ===") + # No configuration needed — shekel auto-detects CrewAI and enforces the cap + try: + with budget(max_usd=5.00, name="global") as b: + crew.kickoff(inputs={"topic": "climate change"}) + print(f"Done. Spent: ${b.spent:.4f}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 2. Per-agent caps — use agent.role as the key directly # ------------------------------------------------------------------ - print("=== CrewAI with budget ===") + print("\n=== Per-agent caps ===") + try: + with budget(max_usd=5.00, name="agents") as b: + # Use agent.role directly — eliminates key mismatch risk + b.agent(researcher.role, max_usd=2.00) + b.agent(writer.role, max_usd=1.00) + crew.kickoff(inputs={"topic": "quantum computing"}) + print(f"Done. Spent: ${b.spent:.4f}") + except AgentBudgetExceededError as e: + print(f"Agent cap exceeded: {e}") + except BudgetExceededError as e: + print(f"Global budget exceeded: {e}") - def on_warn(spent: float, limit: float) -> None: - print(f" Warning: ${spent:.4f} of ${limit:.2f} used") + print(b.tree()) + # agents: $X.XX / $5.00 (direct: $X.XX) + # [agent] Senior Researcher: $X.XX / $2.00 (X%) + # [agent] Content Writer: $X.XX / $1.00 (X%) + # ------------------------------------------------------------------ + # 3. 
Per-agent + per-task caps + # ------------------------------------------------------------------ + print("\n=== Per-agent + per-task caps ===") try: - with budget(max_usd=0.50, warn_at=0.8, on_warn=on_warn) as b: - result = crew.kickoff() - print(result) - print(f"\nCrew cost: ${b.spent:.4f}") - print(b.summary()) + with budget(max_usd=5.00, name="full") as b: + b.agent(researcher.role, max_usd=2.00) + b.agent(writer.role, max_usd=1.00) + # Use task.name directly — eliminates key mismatch risk + b.task(research_task.name, max_usd=1.50) + b.task(write_task.name, max_usd=0.80) + crew.kickoff(inputs={"topic": "renewable energy"}) + print(f"Done. Spent: ${b.spent:.4f}") + except TaskBudgetExceededError as e: + print(f"Task cap exceeded: {e}") + except AgentBudgetExceededError as e: + print(f"Agent cap exceeded: {e}") except BudgetExceededError as e: - print(f"Crew exceeded budget: {e}") + print(f"Global budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 4. b.tree() — full spend breakdown + # ------------------------------------------------------------------ + print("\n=== Spend breakdown ===") + print(b.tree()) + # full: $X.XX / $5.00 (direct: $X.XX) + # [agent] Senior Researcher: $X.XX / $2.00 (X%) + # [agent] Content Writer: $X.XX / $1.00 (X%) + # [task] research: $X.XX / $1.50 (X%) + # [task] write: $X.XX / $0.80 (X%) if __name__ == "__main__": diff --git a/examples/distributed_budgets_demo.py b/examples/distributed_budgets_demo.py new file mode 100644 index 0000000..a3c5fb4 --- /dev/null +++ b/examples/distributed_budgets_demo.py @@ -0,0 +1,117 @@ +# Requires: pip install shekel[openai,redis] +""" +Distributed budgets demo: enforce LLM cost limits across multiple processes. + +RedisBackend uses an atomic Lua script (one round-trip) to enforce rolling-window +caps across any number of workers hitting the same Redis instance. + +Shows three patterns: +1. Basic distributed enforcement with RedisBackend +2. 
Multi-cap spec — simultaneous USD + call-count rolling windows +3. Distributed budget + per-node caps (LangGraph) +""" + +import os + + +def main() -> None: + redis_url = os.environ.get("REDIS_URL", "redis://localhost:6379") + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + print("Set OPENAI_API_KEY to run this demo.") + return + + try: + import openai + from shekel.backends.redis import RedisBackend + except ImportError as e: + print(f"Missing dependency: {e}") + print("Run: pip install shekel[openai] redis") + return + + from shekel import BudgetExceededError, budget + + client = openai.OpenAI(api_key=api_key) + + # ------------------------------------------------------------------ + # 1. Basic distributed enforcement + # ------------------------------------------------------------------ + print("=== Distributed enforcement ===") + try: + backend = RedisBackend(url=redis_url) + with budget(max_usd=5.00, name="shared-pool", backend=backend) as b: + resp = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "What is 2+2?"}], + max_tokens=20, + ) + print(f"Answer: {resp.choices[0].message.content}") + print(f"Spent: ${b.spent:.4f} / $5.00") + except BudgetExceededError as e: + print(f"Distributed budget exceeded: {e}") + except Exception as e: + print(f"Redis unavailable: {e}") + print("Start Redis with: docker run -p 6379:6379 redis:alpine") + return + + # ------------------------------------------------------------------ + # 2. 
Multi-cap rolling-window spec + # ------------------------------------------------------------------ + print("\n=== Multi-cap: $5/hr + 100 calls/hr ===") + try: + backend = RedisBackend(url=redis_url) + with budget("$5/hr + 100 calls/hr", name="api-tier", backend=backend) as b: + for i in range(3): + resp = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": f"Question {i+1}: capital of France?"}], + max_tokens=10, + ) + print(f" [{i+1}] {resp.choices[0].message.content}") + except BudgetExceededError as e: + # e.retry_after tells callers when the window resets + retry = getattr(e, "retry_after", None) + print(f"Rate limit hit. Retry after: {retry:.1f}s" if retry else f"Limit: {e}") + + # ------------------------------------------------------------------ + # 3. Distributed budget + per-node LangGraph caps + # ------------------------------------------------------------------ + print("\n=== Distributed + per-node caps ===") + try: + from langgraph.graph import END, StateGraph # type: ignore[import] + from typing_extensions import TypedDict + + class State(TypedDict): + query: str + answer: str + + def answer_node(state: State) -> State: + resp = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": state["query"]}], + max_tokens=40, + ) + return {**state, "answer": resp.choices[0].message.content or ""} + + graph = StateGraph(State) + graph.add_node("answer", answer_node) + graph.set_entry_point("answer") + graph.add_edge("answer", END) + app = graph.compile() + + backend = RedisBackend(url=redis_url) + with budget("$5/hr", name="graph-pool", backend=backend) as b: + b.node("answer", max_usd=0.10) + result = app.invoke({"query": "Name a planet.", "answer": ""}) + print(f"Answer: {result['answer']}") + + print(b.tree()) + + except ImportError: + print("langgraph not installed — pip install langgraph") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + +if __name__ 
== "__main__": + main() diff --git a/examples/langchain_demo.py b/examples/langchain_demo.py new file mode 100644 index 0000000..d7c8b07 --- /dev/null +++ b/examples/langchain_demo.py @@ -0,0 +1,101 @@ +# Requires: pip install shekel[openai] langchain langchain-openai +""" +LangChain demo: per-chain circuit breaking with shekel. + +Shekel patches Runnable._call_with_config and RunnableSequence.invoke +transparently — every chain gets a pre-execution budget gate with no +chain changes needed. + +Shows three patterns: +1. Per-chain USD caps — ChainBudgetExceededError before the chain runs +2. Nested budgets with chain caps — tree() shows spend breakdown +3. Combining node caps (LangGraph) and chain caps (LangChain) +""" + +import os + + +def main() -> None: + try: + from langchain_core.output_parsers import StrOutputParser # type: ignore[import] + from langchain_core.prompts import ChatPromptTemplate # type: ignore[import] + from langchain_openai import ChatOpenAI # type: ignore[import] + except ImportError as e: + print(f"Missing dependency: {e}") + print("Run: pip install shekel[openai] langchain langchain-openai") + return + + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + print("Set OPENAI_API_KEY to run this demo.") + return + + from shekel import BudgetExceededError, budget + from shekel.providers.langchain import ChainBudgetExceededError # type: ignore[import] + + llm = ChatOpenAI(model="gpt-4o-mini", api_key=api_key, max_tokens=80) + + retriever_prompt = ChatPromptTemplate.from_template("Find facts about: {topic}") + summarizer_prompt = ChatPromptTemplate.from_template("Summarize in one sentence: {text}") + + retriever_chain = retriever_prompt | llm | StrOutputParser() + summarizer_chain = summarizer_prompt | llm | StrOutputParser() + + # ------------------------------------------------------------------ + # 1. 
Per-chain USD caps + # ------------------------------------------------------------------ + print("=== Per-chain caps ===") + try: + with budget(max_usd=5.00, name="pipeline") as b: + b.chain("retriever", max_usd=0.20) + b.chain("summarizer", max_usd=1.00) + + facts = retriever_chain.invoke({"topic": "climate change"}) + summary = summarizer_chain.invoke({"text": facts}) + print(f"Summary: {summary}") + except (BudgetExceededError, ChainBudgetExceededError) as e: + print(f"Budget exceeded: {e}") + + print(b.tree()) + # pipeline: $X.XX / $5.00 + # [chain] retriever: $X.XX / $0.20 (X%) + # [chain] summarizer: $X.XX / $1.00 (X%) + + # ------------------------------------------------------------------ + # 2. Chain cap exceeded — ChainBudgetExceededError + # ------------------------------------------------------------------ + print("\n=== Chain cap exceeded ===") + try: + with budget(max_usd=5.00, name="tight") as b: + b.chain("retriever", max_usd=0.00001) # intentionally tiny + retriever_chain.invoke({"topic": "AI"}) + except ChainBudgetExceededError as e: + print(f"Chain '{e.chain_name}' exceeded: ${e.spent:.6f} > ${e.limit:.6f}") + except BudgetExceededError as e: + print(f"Global budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 3. 
Nested budgets — per-stage isolation + # ------------------------------------------------------------------ + print("\n=== Nested per-stage budgets ===") + workflow = budget(max_usd=10.00, name="workflow") + try: + with workflow: + with budget(max_usd=2.00, name="research"): + facts = retriever_chain.invoke({"topic": "quantum computing"}) + + with budget(max_usd=3.00, name="writing"): + summary = summarizer_chain.invoke({"text": facts}) + + print(f"Result: {summary}") + except BudgetExceededError as e: + print(f"Budget exceeded: {e}") + + print(workflow.tree()) + # workflow: $X.XX / $10.00 + # research: $X.XX / $2.00 + # writing: $X.XX / $3.00 + + +if __name__ == "__main__": + main() diff --git a/examples/langgraph_demo.py b/examples/langgraph_demo.py index 4c0250a..318f3fe 100644 --- a/examples/langgraph_demo.py +++ b/examples/langgraph_demo.py @@ -1,14 +1,14 @@ -# Requires: pip install shekel[openai] +# Requires: pip install shekel[openai] langgraph typing_extensions """ -LangGraph demo: budget enforcement with shekel. +LangGraph demo: per-node circuit breaking with shekel. -Shekel works with LangGraph out of the box — just wrap with budget(). -All LLM calls inside graph nodes are automatically tracked. +Shekel patches StateGraph.add_node() transparently — every node gets a +pre-execution budget gate with no graph changes needed. Shows three patterns: -1. Basic budget enforcement -2. Fallback model when budget threshold is reached -3. Nested budgets for multi-node graphs +1. Per-node USD caps — NodeBudgetExceededError before the node runs +2. Global + per-node caps together — tree() shows spend breakdown +3. 
Distributed enforcement with RedisBackend (optional) """ import os @@ -30,53 +30,93 @@ def main() -> None: return from shekel import BudgetExceededError, budget + from shekel.providers.langgraph import NodeBudgetExceededError # type: ignore[import] client = openai.OpenAI(api_key=api_key) class State(TypedDict): - question: str - answer: str + query: str + data: str + summary: str - def call_llm(state: State) -> State: - response = client.chat.completions.create( + def fetch_data(state: State) -> State: + resp = client.chat.completions.create( model="gpt-4o-mini", - messages=[{"role": "user", "content": state["question"]}], - max_tokens=50, + messages=[{"role": "user", "content": f"Find facts about: {state['query']}"}], + max_tokens=100, ) - return {"question": state["question"], "answer": response.choices[0].message.content} + return {**state, "data": resp.choices[0].message.content or ""} + + def summarize(state: State) -> State: + resp = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": f"Summarize: {state['data']}"}], + max_tokens=60, + ) + return {**state, "summary": resp.choices[0].message.content or ""} graph = StateGraph(State) - graph.add_node("llm", call_llm) - graph.set_entry_point("llm") - graph.add_edge("llm", END) + graph.add_node("fetch_data", fetch_data) + graph.add_node("summarize", summarize) + graph.set_entry_point("fetch_data") + graph.add_edge("fetch_data", "summarize") + graph.add_edge("summarize", END) app = graph.compile() # ------------------------------------------------------------------ - # 1. Basic budget enforcement + # 1. 
Per-node USD caps with global budget # ------------------------------------------------------------------ - print("=== Basic budget enforcement ===") + print("=== Per-node caps ===") try: - with budget(max_usd=0.10, name="demo", warn_at=0.8) as b: - result = app.invoke({"question": "What is 2+2?", "answer": ""}) - print(f"Answer: {result['answer']}") - print(f"Spent: ${b.spent:.4f} / ${b.limit:.2f}") - except BudgetExceededError as e: + with budget(max_usd=5.00, name="pipeline") as b: + b.node("fetch_data", max_usd=0.50) + b.node("summarize", max_usd=1.00) + + result = app.invoke({"query": "climate change", "data": "", "summary": ""}) + print(f"Summary: {result['summary']}") + except (BudgetExceededError, NodeBudgetExceededError) as e: print(f"Budget exceeded: {e}") + print(b.tree()) + # pipeline: $X.XX / $5.00 + # [node] fetch_data: $X.XX / $0.50 (X%) + # [node] summarize: $X.XX / $1.00 (X%) + # ------------------------------------------------------------------ - # 2. Fallback model when threshold is reached + # 2. 
Per-node cap exceeded — NodeBudgetExceededError # ------------------------------------------------------------------ - print("\n=== Fallback model ===") - with budget( - max_usd=0.001, - name="fallback-demo", - fallback={"at_pct": 0.5, "model": "gpt-4o-mini"}, - ) as b: - result = app.invoke({"question": "What is the capital of France?", "answer": ""}) - print(f"Answer: {result['answer']}") - if b.model_switched: - print(f"Switched to fallback at ${b.switched_at_usd:.6f}") - print(f"Total: ${b.spent:.4f}") + print("\n=== Node cap exceeded ===") + try: + with budget(max_usd=5.00, name="tight") as b: + b.node("fetch_data", max_usd=0.0001) # intentionally tiny cap + app.invoke({"query": "AI trends", "data": "", "summary": ""}) + except NodeBudgetExceededError as e: + print(f"Node '{e.node_name}' exceeded: ${e.spent:.6f} > ${e.limit:.6f}") + except BudgetExceededError as e: + print(f"Global budget exceeded: {e}") + + # ------------------------------------------------------------------ + # 3. 
Distributed enforcement (Redis — optional) + # ------------------------------------------------------------------ + redis_url = os.environ.get("REDIS_URL") + if redis_url: + print("\n=== Distributed budget (Redis) ===") + try: + from shekel.backends.redis import RedisBackend + + backend = RedisBackend() + with budget("$5/hr + 100 calls/hr", name="distributed-pipeline", backend=backend) as b: + b.node("fetch_data", max_usd=0.50) + b.node("summarize", max_usd=1.00) + result = app.invoke({"query": "quantum computing", "data": "", "summary": ""}) + print(f"Summary: {result['summary']}") + print(b.tree()) + except ImportError: + print("redis package not installed — pip install shekel[redis]") + except Exception as e: + print(f"Redis error: {e}") + else: + print("\n(Set REDIS_URL to demo distributed enforcement)") if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 5998080..1dcfda5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "shekel" -version = "0.2.9" +version = "0.3.1" description = "LLM budget enforcement and cost tracking. Zero config — with budget(max_usd=1.00): run_agent(). Or: shekel run agent.py --budget 1. Works with LangGraph, CrewAI, raw OpenAI/Anthropic/Gemini." 
readme = "README.md" license = { file = "LICENSE" } @@ -79,7 +79,8 @@ litellm = ["litellm>=1.0.0"] gemini = ["google-genai>=1.0.0"] huggingface = ["huggingface-hub>=0.20.0"] otel = ["opentelemetry-api>=1.0.0"] -all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0", "google-genai>=1.0.0", "huggingface-hub>=0.20.0", "opentelemetry-api>=1.0.0"] +redis = ["redis>=4.0.0"] +all = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "litellm>=1.0.0", "google-genai>=1.0.0", "huggingface-hub>=0.20.0", "opentelemetry-api>=1.0.0", "redis>=4.0.0"] all-models = ["openai>=1.0.0", "anthropic>=0.7.0", "langfuse>=2.0.0", "tokencost>=0.1.0"] cli = ["click>=8.0.0"] dev = [ @@ -99,6 +100,9 @@ dev = [ "huggingface-hub>=0.20.0", "opentelemetry-api>=1.0.0", "opentelemetry-sdk>=1.0.0", + "redis>=4.0.0", + "fakeredis>=2.0.0", + "testcontainers[redis]>=4.0.0", ] [project.scripts] @@ -184,6 +188,14 @@ ignore_missing_imports = true module = ["langchain_core", "langchain_core.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["redis", "redis.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = ["fakeredis", "fakeredis.*"] +ignore_missing_imports = true + [[tool.mypy.overrides]] module = ["crewai", "crewai.*"] ignore_missing_imports = true diff --git a/shekel/__init__.py b/shekel/__init__.py index 828c6ee..2b99f9a 100644 --- a/shekel/__init__.py +++ b/shekel/__init__.py @@ -5,19 +5,39 @@ from shekel._budget import Budget from shekel._decorator import with_budget from shekel._tool import tool -from shekel.exceptions import BudgetExceededError, ToolBudgetExceededError +from shekel.exceptions import ( + AgentBudgetExceededError, + BudgetConfigMismatchError, + BudgetExceededError, + ChainBudgetExceededError, + NodeBudgetExceededError, + SessionBudgetExceededError, + TaskBudgetExceededError, + ToolBudgetExceededError, +) -__version__ = "0.2.9" +__version__ = "0.3.1" __all__ = [ "budget", "Budget", "TemporalBudget", "with_budget", 
"BudgetExceededError", + "BudgetConfigMismatchError", "ToolBudgetExceededError", + "NodeBudgetExceededError", + "AgentBudgetExceededError", + "TaskBudgetExceededError", + "SessionBudgetExceededError", + "ChainBudgetExceededError", "tool", ] +# Cap-related kwargs that must NOT be mixed with a spec string. +_CAP_KWARGS = frozenset( + {"max_usd", "max_llm_calls", "max_tool_calls", "max_tokens", "window_seconds"} +) + def budget( spec: str | None = None, @@ -27,24 +47,33 @@ def budget( ) -> Budget: """Factory for creating Budget or TemporalBudget instances. - Usage:: + Two forms — never mixed:: - # Temporal (rolling-window) budget from spec string: + # Spec string (per-cap windows, richer form): b = budget("$5/hr", name="api") + b = budget("$5/hr + 100 calls/hr", name="api") - # Temporal budget from kwargs: + # Kwargs (single shared window, convenience shorthand): b = budget(max_usd=5.0, window_seconds=3600, name="api") + b = budget(max_usd=5.0, max_llm_calls=100, window_seconds=3600, name="api") - # Regular budget (backward-compatible): + # Regular budget (no rolling window): b = budget(max_usd=5.0) """ - from shekel._temporal import TemporalBudget, _parse_spec + from shekel._temporal import TemporalBudget, _parse_cap_spec if spec is not None: - max_usd, window_seconds = _parse_spec(spec) + # Form-mixing guard: spec string + cap/window kwargs is an error. + mixed = _CAP_KWARGS.intersection(kwargs) + if mixed: + raise ValueError( + f"Cannot mix spec string with cap/window kwargs: {sorted(mixed)}. " + "Use either the spec-string form or the kwargs form, never both." + ) if not name: raise ValueError('budget(spec) requires name=, e.g. 
budget("$5/hr", name="api")') - return TemporalBudget(max_usd=max_usd, window_seconds=window_seconds, name=name, **kwargs) + caps = _parse_cap_spec(spec) + return TemporalBudget(caps=caps, name=name, **kwargs) window_seconds = kwargs.pop("window_seconds", None) if window_seconds is not None: diff --git a/shekel/_budget.py b/shekel/_budget.py index c44edf0..99993ad 100644 --- a/shekel/_budget.py +++ b/shekel/_budget.py @@ -4,13 +4,25 @@ import time import warnings from contextvars import Token -from typing import TYPE_CHECKING, Any, Callable, TypedDict +from dataclasses import dataclass, field +from typing import Any, Callable, TypedDict from shekel import _context, _patch from shekel.exceptions import BudgetExceededError, ToolBudgetExceededError -if TYPE_CHECKING: - pass + +@dataclass +class ComponentBudget: + """Lightweight cap tracker for a named node, agent, or task (v0.3.1). + + Stores the declared USD limit and accumulated spend for a single + framework component. Used by framework adapters (LangGraph, CrewAI, + OpenClaw) to enforce per-component circuit breaking. 
+ """ + + name: str + max_usd: float + _spent: float = field(default=0.0, init=False) class CallRecord(TypedDict): @@ -187,6 +199,13 @@ def __init__( self._tool_calls: list[ToolCallRecord] = [] self._tool_warn_fired: bool = False + # Component budget support (v0.3.1) + self._node_budgets: dict[str, ComponentBudget] = {} + self._agent_budgets: dict[str, ComponentBudget] = {} + self._task_budgets: dict[str, ComponentBudget] = {} + self._chain_budgets: dict[str, ComponentBudget] = {} + self._runtime: Any = None + # ------------------------------------------------------------------ # Internal state reset # ------------------------------------------------------------------ @@ -294,6 +313,10 @@ def __enter__(self) -> Budget: self._effective_tool_call_limit = self.max_tool_calls self._ctx_token = _context.set_active_budget(self) + from shekel._runtime import ShekelRuntime + + self._runtime = ShekelRuntime(self) + self._runtime.probe() return self def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -345,6 +368,9 @@ def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: from shekel._context import _active_budget _active_budget.reset(self._ctx_token) + if self._runtime is not None: + self._runtime.release() + self._runtime = None _patch.remove_patches() # returning None (not False) — never suppress exceptions @@ -428,6 +454,10 @@ async def __aenter__(self) -> Budget: self._effective_tool_call_limit = self.max_tool_calls self._ctx_token = _context.set_active_budget(self) + from shekel._runtime import ShekelRuntime + + self._runtime = ShekelRuntime(self) + self._runtime.probe() return self async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: @@ -478,6 +508,9 @@ async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> from shekel._context import _active_budget _active_budget.reset(self._ctx_token) + if self._runtime is not None: + self._runtime.release() + self._runtime 
= None _patch.remove_patches() # ------------------------------------------------------------------ @@ -910,8 +943,80 @@ def tree(self, _indent: int = 0) -> str: else: lines.append(child.tree(_indent=_indent + 1)) + for kind, budgets in [ + ("node", self._node_budgets), + ("agent", self._agent_budgets), + ("task", self._task_budgets), + ]: + for comp_name, cb in budgets.items(): + limit_str = f"${cb.max_usd:.2f}" + pct = f"{cb._spent / cb.max_usd * 100:.1f}%" if cb.max_usd else "n/a" + lines.append( + f"{prefix} [{kind}] {comp_name}: ${cb._spent:.4f} / {limit_str} ({pct})" + ) + return "\n".join(lines) + # ------------------------------------------------------------------ + # Component budget API (v0.3.1) + # ------------------------------------------------------------------ + + def node(self, name: str, max_usd: float) -> Budget: + """Register an explicit USD cap for a LangGraph node. + + Returns ``self`` for fluent chaining:: + + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=0.50).node("summarize", max_usd=1.00) + graph.invoke(...) + """ + if max_usd <= 0: + raise ValueError(f"node max_usd must be positive, got {max_usd}") + self._node_budgets[name] = ComponentBudget(name=name, max_usd=max_usd) + return self + + def agent(self, name: str, max_usd: float) -> Budget: + """Register an explicit USD cap for an agent (CrewAI / OpenClaw). + + Returns ``self`` for fluent chaining:: + + with budget(max_usd=10.00) as b: + b.agent("researcher", max_usd=3.00).agent("writer", max_usd=2.00) + crew.kickoff() + """ + if max_usd <= 0: + raise ValueError(f"agent max_usd must be positive, got {max_usd}") + self._agent_budgets[name] = ComponentBudget(name=name, max_usd=max_usd) + return self + + def task(self, name: str, max_usd: float) -> Budget: + """Register an explicit USD cap for a task (CrewAI). 
+ + Returns ``self`` for fluent chaining:: + + with budget(max_usd=5.00) as b: + b.task("research", max_usd=1.00).task("write", max_usd=0.50) + crew.kickoff() + """ + if max_usd <= 0: + raise ValueError(f"task max_usd must be positive, got {max_usd}") + self._task_budgets[name] = ComponentBudget(name=name, max_usd=max_usd) + return self + + def chain(self, name: str, max_usd: float) -> Budget: + """Register an explicit USD cap for a named LangChain chain or runnable. + + Returns ``self`` for fluent chaining:: + + with budget(max_usd=5.00) as b: + b.chain("summarize", max_usd=0.50).chain("research", max_usd=1.00) + result = summarize_chain.invoke(inputs) + """ + if max_usd <= 0: + raise ValueError(f"chain max_usd must be positive, got {max_usd}") + self._chain_budgets[name] = ComponentBudget(name=name, max_usd=max_usd) + return self + # ------------------------------------------------------------------ # Summary # ------------------------------------------------------------------ diff --git a/shekel/_runtime.py b/shekel/_runtime.py new file mode 100644 index 0000000..72a52f5 --- /dev/null +++ b/shekel/_runtime.py @@ -0,0 +1,88 @@ +"""ShekelRuntime — framework detection and adapter wiring (v0.3.1). + +Probed once at budget open. Activates any installed framework adapters +(LangGraph, CrewAI, OpenClaw, ...). The adapter list starts empty; later +phases register into it via ShekelRuntime.register(). +""" + +from __future__ import annotations + +from typing import Any + + +class ShekelRuntime: + """Detects installed framework adapters and wires them at budget open/close. 
+ + Usage (internal — called by Budget.__enter__ / __exit__):: + + runtime = ShekelRuntime(budget_instance) + runtime.probe() # on __enter__ + runtime.release() # on __exit__ + + Framework adapters are registered once at import time by each phase:: + + ShekelRuntime.register(LangGraphAdapter) # v0.3.2 + ShekelRuntime.register(CrewAIAdapter) # v0.3.3 + """ + + _adapter_registry: list[type[Any]] = [] + + def __init__(self, budget: Any) -> None: + self._budget = budget + self._active_adapters: list[Any] = [] + + def probe(self) -> None: + """Activate all registered framework adapters whose packages are installed. + + Called once on ``budget.__enter__()``. Adapters that raise + ``ImportError`` are silently skipped (framework not installed). + """ + for adapter_cls in ShekelRuntime._adapter_registry: + try: + adapter = adapter_cls() + adapter.install_patches(self._budget) + self._active_adapters.append(adapter) + except ImportError: + pass # framework not installed — silent skip + + def release(self) -> None: + """Deactivate all adapters that were activated by probe(). + + Called once on ``budget.__exit__()``. Exceptions from + ``remove_patches()`` are suppressed to avoid masking the original + exception during budget exit. + """ + for adapter in self._active_adapters: + try: + adapter.remove_patches(self._budget) + except Exception: # pragma: no cover — defensive cleanup + pass + self._active_adapters.clear() + + @classmethod + def register(cls, adapter_cls: type[Any]) -> None: + """Register a framework adapter class. 
+ + Called once per phase at module import time:: + + ShekelRuntime.register(LangGraphAdapter) + """ + cls._adapter_registry.append(adapter_cls) + + +# --------------------------------------------------------------------------- +# Built-in framework adapters — registered once at import time +# --------------------------------------------------------------------------- + + +def _register_builtin_adapters() -> None: + from shekel.providers.crewai import CrewAIExecutionAdapter # noqa: PLC0415 + from shekel.providers.langchain import LangChainRunnerAdapter # noqa: PLC0415 + from shekel.providers.langgraph import LangGraphAdapter # noqa: PLC0415 + + ShekelRuntime.register(LangGraphAdapter) + ShekelRuntime.register(LangChainRunnerAdapter) + ShekelRuntime.register(CrewAIExecutionAdapter) + + +_register_builtin_adapters() diff --git a/shekel/_temporal.py b/shekel/_temporal.py index 5eff255..e518d38 100644 --- a/shekel/_temporal.py +++ b/shekel/_temporal.py @@ -15,14 +15,104 @@ _UNIT_SECONDS: dict[str, int] = {"s": 1, "sec": 1, "min": 60, "hr": 3600, "h": 3600} _CALENDAR_UNITS = {"day", "days", "week", "weeks", "month", "months"} + +# Maps spec-string type tokens → internal counter names. +# Checked case-insensitively after lowercasing. +_CAP_TYPE_MAP: dict[str, str] = { + "usd": "usd", + "call": "llm_calls", + "calls": "llm_calls", + "tool": "tool_calls", + "tools": "tool_calls", + "token": "tokens", + "tokens": "tokens", +} + +# Matches one cap term, e.g.: "$5/hr", "5 usd/hr", "100 calls/30min", "20 tools/1hr" +# Groups: usd_dollar (from $N), gen_num + gen_type (from N calls/tools/etc.), +# window_count (optional multiplier), unit +_CAP_TERM_RE = re.compile( + r"^\s*" + r"(?:" + r" \$(?P[\d.]+)" + r" |(?P[\d.]+)\s+(?Pusd|calls?|tools?|tokens?)" + r")" + r"\s*(?:/\s*|\s+per\s+|\s+)" + r"(?P[\d.]*)\s*" + r"(?Psec|min|hr|h|s)\b" + r"\s*$", + re.IGNORECASE | re.VERBOSE, +) + +# Legacy single-cap regex kept for _parse_spec backward compat. 
_SPEC_RE = re.compile( r"^\$?(?P[\d.]+)\s*(?:per\s+)?(?P[\d.]*)\s*(?P\w+)$", re.IGNORECASE, ) +# Ordered check sequence: usd is checked before llm_calls before tool_calls. +_CHECK_ORDER = ("usd", "llm_calls", "tool_calls", "tokens") + + +def _parse_one_cap_term(term: str) -> tuple[str, float | None, float]: + """Parse a single cap term like '$5/hr' or '100 calls/30min'. + + Returns: + (counter_name, limit, window_seconds) + """ + m = _CAP_TERM_RE.match(term.strip()) + if not m: + raise ValueError(f"Cannot parse cap term: {term!r}") + + if m.group("usd_dollar") is not None: + counter = "usd" + amount = float(m.group("usd_dollar")) + else: + gen_type = m.group("gen_type").lower() + # strip trailing 's' handled by the map + mapped = _CAP_TYPE_MAP.get(gen_type) + if mapped is None: # pragma: no cover — regex restricts gen_type to known tokens + raise ValueError(f"Unknown cap type: {m.group('gen_type')!r}") + counter = mapped + amount = float(m.group("gen_num")) + + if amount <= 0: + raise ValueError(f"Cap amount must be > 0, got {amount}") + + unit = m.group("unit").lower() + if unit in _CALENDAR_UNITS: # pragma: no cover — regex restricts unit to known tokens + raise ValueError(f"Calendar unit {unit!r} not supported. Use 's', 'min', or 'hr'.") + if unit not in _UNIT_SECONDS: # pragma: no cover — regex restricts unit to known tokens + raise ValueError(f"Unknown time unit: {unit!r}") + + count_str = m.group("window_count") + count = float(count_str) if count_str else 1.0 + window_s = count * _UNIT_SECONDS[unit] + return counter, amount, window_s + + +def _parse_cap_spec(spec: str) -> list[tuple[str, float | None, float]]: + """Parse a (possibly multi-cap) spec string into a list of cap tuples. + + Examples:: + + _parse_cap_spec("$5/hr") # [("usd", 5.0, 3600.0)] + _parse_cap_spec("100 calls/hr") # [("llm_calls", 100.0, 3600.0)] + _parse_cap_spec("$5/hr + 100 calls/30min") # two caps, different windows + + Returns: + List of (counter_name, limit, window_seconds) tuples. 
+ """ + terms = [t.strip() for t in spec.split("+")] + return [_parse_one_cap_term(t) for t in terms if t] + def _parse_spec(spec: str) -> tuple[float, float]: - """Parse "$5/hr" or "$5 per 30min" -> (max_usd, window_seconds).""" + """Parse "$5/hr" or "$5 per 30min" -> (max_usd, window_seconds). + + Backward-compatible single-USD-cap parser. + For multi-cap specs, use _parse_cap_spec() directly. + """ normalized = spec.strip().replace("/", " ") m = _SPEC_RE.match(normalized) if not m: @@ -42,80 +132,187 @@ def _parse_spec(spec: str) -> tuple[float, float]: @runtime_checkable class TemporalBudgetBackend(Protocol): - def get_state(self, budget_name: str) -> tuple[float, float | None]: - """Return (spent_usd, window_start_monotonic) for the named budget.""" - pass + """Generic named-counter backend protocol for rolling-window budgets. + + The backend is unaware of USD vs. calls — it manages named counters + with per-counter limits and windows. All-or-nothing atomicity: if any + counter would exceed its limit, none are incremented. + """ def check_and_add( self, budget_name: str, - amount: float, - max_usd: float, - window_seconds: float, - ) -> bool: - """Atomically check limit and add amount. Returns False if it would exceed.""" - pass + amounts: dict[str, float], + limits: dict[str, float | None], + windows: dict[str, float], + ) -> tuple[bool, str | None]: + """Atomically check limits and add amounts. + + Args: + budget_name: Unique identifier for this budget. + amounts: Counter increments for this call, e.g. {"usd": 0.03, "llm_calls": 1}. + limits: Per-counter caps; None means tracked but uncapped. + windows: Per-counter window durations in seconds. + + Returns: + (allowed, exceeded_counter_name_or_None). + If allowed is False, exceeded_counter names the first counter that + would have exceeded (checked in deterministic order). 
+ """ + pass # pragma: no cover — Protocol stub; implemented by concrete backends + + def get_state(self, budget_name: str) -> dict[str, float]: + """Return current-window spend for each counter.""" + pass # pragma: no cover — Protocol stub; implemented by concrete backends def reset(self, budget_name: str) -> None: - """Reset the window state for the given budget name.""" - pass + """Reset all counters for the given budget name.""" + pass # pragma: no cover — Protocol stub; implemented by concrete backends class InMemoryBackend: """Simple in-process rolling-window backend. NOT thread-safe — each thread/task should use its own budget instance. + Implements the generic TemporalBudgetBackend protocol. """ def __init__(self) -> None: - self._state: dict[str, tuple[float, float | None]] = {} - - def get_state(self, budget_name: str) -> tuple[float, float | None]: - return self._state.get(budget_name, (0.0, None)) + # {budget_name: {counter: (spent, window_start_monotonic | None)}} + self._state: dict[str, dict[str, tuple[float, float | None]]] = {} def check_and_add( self, budget_name: str, - amount: float, - max_usd: float, - window_seconds: float, - ) -> bool: - spent, window_start = self.get_state(budget_name) + amounts: dict[str, float], + limits: dict[str, float | None], + windows: dict[str, float], + ) -> tuple[bool, str | None]: + """Atomically check limits and add amounts (all-or-nothing).""" now = time.monotonic() - # If window has expired, reset it - if window_start is not None and (now - window_start) >= window_seconds: - spent = 0.0 - window_start = None - if spent + amount > max_usd: - return False - self._state[budget_name] = ( - spent + amount, - window_start if window_start is not None else now, - ) - return True + counters = self._state.setdefault(budget_name, {}) + + # Phase 1: compute effective current spend (apply window resets to a temp view). 
+ effective: dict[str, float] = {} + for counter in amounts: + spent, window_start = counters.get(counter, (0.0, None)) + window_s = windows[counter] + if window_start is not None and (now - window_start) >= window_s: + spent = 0.0 + effective[counter] = spent + + # Phase 2: check limits in deterministic order. + for counter in _CHECK_ORDER: + if counter not in amounts: + continue + limit = limits.get(counter) + if limit is not None and effective[counter] + amounts[counter] > limit: + return False, counter + + # Phase 3: commit — increment all counters. + for counter, amount in amounts.items(): + prev_spent, window_start = counters.get(counter, (0.0, None)) + window_s = windows[counter] + # Apply window reset if expired. + if window_start is not None and (now - window_start) >= window_s: + prev_spent = 0.0 + window_start = None + new_window_start = window_start if window_start is not None else now + counters[counter] = (prev_spent + amount, new_window_start) + + return True, None + + def get_state(self, budget_name: str) -> dict[str, float]: + """Return current-window spent amount for each counter.""" + counters = self._state.get(budget_name, {}) + result: dict[str, float] = {} + for counter, (spent, _window_start) in counters.items(): + result[counter] = spent + return result + + def get_window_info(self, budget_name: str) -> dict[str, tuple[float, float | None]]: + """Return {counter: (spent, window_start)} for observability.""" + return dict(self._state.get(budget_name, {})) def reset(self, budget_name: str) -> None: self._state.pop(budget_name, None) class TemporalBudget(Budget): - """Rolling-window budget that resets after each window_seconds period.""" + """Rolling-window budget that resets after each window_seconds period. + + Supports multiple simultaneous caps (usd, llm_calls, tool_calls) each + with their own independent rolling window. 
+ + Two configuration forms — never mixed: + + Spec string (per-cap windows):: + + budget("$5/hr + 100 calls/hr", name="api") + + Kwargs (single shared window):: + + budget(max_usd=5.0, max_llm_calls=100, window_seconds=3600, name="api") + """ def __init__( self, - max_usd: float, - window_seconds: float, + max_usd: float | None = None, + window_seconds: float | None = None, *, name: str, backend: TemporalBudgetBackend | None = None, + caps: list[tuple[str, float | None, float]] | None = None, **kwargs: Any, ) -> None: if not name: raise ValueError("TemporalBudget requires a non-empty name=") - super().__init__(max_usd=max_usd, name=name, **kwargs) - self._window_seconds = window_seconds + + # Extract multi-cap kwargs that should NOT be passed to Budget + # (Budget enforces them cumulatively; TemporalBudget enforces via backend). + max_llm_calls: int | None = kwargs.pop("max_llm_calls", None) + max_tool_calls: int | None = kwargs.pop("max_tool_calls", None) + + # Resolve effective max_usd from caps if not explicitly provided, + # so parent Budget.max_usd is set correctly without post-init override. + effective_max_usd = max_usd + if effective_max_usd is None and caps is not None: + for counter, limit, _ in caps: + if counter == "usd" and limit is not None: + effective_max_usd = limit + break + + # Pass max_usd to parent for .spent / .remaining / .limit property tracking. + super().__init__(max_usd=effective_max_usd, name=name, **kwargs) self._backend: TemporalBudgetBackend = backend or InMemoryBackend() + if caps is not None: + # Structured caps from factory (spec-string form). + self._caps: dict[str, tuple[float | None, float]] = { + counter: (limit, window_s) for counter, limit, window_s in caps + } + # Note: effective_max_usd already passed to super().__init__() above + else: + # Build caps from kwargs. 
+ if window_seconds is None: + raise ValueError("TemporalBudget requires window_seconds (or use spec-string form)") + self._caps = {} + if max_usd is not None: + self._caps["usd"] = (max_usd, window_seconds) + if max_llm_calls is not None: + self._caps["llm_calls"] = (float(max_llm_calls), window_seconds) + if max_tool_calls is not None: + self._caps["tool_calls"] = (float(max_tool_calls), window_seconds) + if not self._caps: + raise ValueError( + "TemporalBudget requires at least one cap (max_usd, max_llm_calls, etc.)" + ) + + # Longest window in caps — used for legacy _window_seconds attribute. + self._window_seconds: float = ( + max(v[1] for v in self._caps.values()) if self._caps else (window_seconds or 3600.0) + ) + def _check_temporal_ancestor(self) -> None: """Raise ValueError if any ancestor budget in the current stack is a TemporalBudget.""" from shekel import _context @@ -130,55 +327,100 @@ def _check_temporal_ancestor(self) -> None: current = current.parent def _lazy_window_reset(self) -> None: - """If window has expired, emit on_window_reset event.""" + """If the primary window has expired since last entry, emit on_window_reset.""" budget_name = self.name or "unnamed" - spent, window_start = self._backend.get_state(budget_name) - if window_start is None: + + # Use get_window_info if available (InMemoryBackend exposes it). + if not hasattr(self._backend, "get_window_info"): return + + info = self._backend.get_window_info(budget_name) + if not info: + return + now = time.monotonic() - if (now - window_start) >= self._window_seconds: - try: - from shekel.integrations import AdapterRegistry - - AdapterRegistry.emit_event( - "on_window_reset", - { - "budget_name": self.name, - "window_seconds": self._window_seconds, - "previous_spent": spent, - }, - ) - except Exception: # noqa: BLE001 — adapter must never crash user code - pass + # Check primary cap (usd if present, else first cap). 
+ primary = "usd" if "usd" in info else next(iter(info)) + spent, window_start = info[primary] + _, primary_window_s = self._caps.get(primary, (None, self._window_seconds)) + + if window_start is None: + return + if (now - window_start) < primary_window_s: + return + + try: + from shekel.integrations import AdapterRegistry + + AdapterRegistry.emit_event( + "on_window_reset", + { + "budget_name": self.name, + "window_seconds": primary_window_s, + "previous_spent": spent, + }, + ) + except Exception: # noqa: BLE001 — adapter must never crash user code + pass def _record_spend(self, cost: float, model: str, tokens: dict[str, int]) -> None: - """Override to enforce rolling-window spend before calling parent.""" + """Override to enforce rolling-window spend via backend before calling parent.""" budget_name = self.name or "unnamed" - max_usd = self._effective_limit - if max_usd is not None: + + # Build amounts/limits/windows for backend call. + # Only include LLM-relevant counters (usd + llm_calls). + amounts: dict[str, float] = {} + limits: dict[str, float | None] = {} + windows: dict[str, float] = {} + + for counter, (limit, window_s) in self._caps.items(): + if counter == "usd": + amounts[counter] = cost + limits[counter] = limit + windows[counter] = window_s + elif counter == "llm_calls": + amounts[counter] = 1.0 + limits[counter] = limit + windows[counter] = window_s + # tool_calls are handled separately via _record_tool_call / _check_tool_limit + + if amounts: + # Gather pre-call state for error payload (window_spent, retry_after). 
now = time.monotonic() - current_spent, window_start = self._backend.get_state(budget_name) - # Compute window-aware current spend (for error payload) - if window_start is not None and (now - window_start) >= self._window_seconds: - # Window has expired — treat as fresh - current_spent = 0.0 - window_start = None + pre_state: dict[str, tuple[float, float | None]] = {} + if hasattr(self._backend, "get_window_info"): + pre_state = self._backend.get_window_info(budget_name) - accepted = self._backend.check_and_add(budget_name, cost, max_usd, self._window_seconds) - if not accepted: - # Compute retry_after: time remaining in current window + allowed, exceeded = self._backend.check_and_add(budget_name, amounts, limits, windows) + if not allowed: + # Compute retry_after and window_spent for the exceeded counter. retry_after: float | None = None - if window_start is not None: - elapsed = now - window_start - retry_after = max(0.0, self._window_seconds - elapsed) + window_spent: float | None = None + if exceeded and exceeded in pre_state: + prev_spent, window_start = pre_state[exceeded] + window_s = windows.get(exceeded, self._window_seconds) + # Check if window had already expired before this call. 
+ if window_start is not None and (now - window_start) < window_s: + elapsed = now - window_start + retry_after = max(0.0, window_s - elapsed) + window_spent = prev_spent + else: + window_spent = 0.0 # fresh window + + exc_limit = limits.get(exceeded or "usd") or 0.0 + exc_spent = ( + pre_state.get(exceeded or "usd", (0.0, None))[0] if exceeded else 0.0 + ) + (amounts.get(exceeded or "usd", 0.0)) raise BudgetExceededError( - spent=current_spent + cost, - limit=max_usd, + spent=exc_spent, + limit=exc_limit, model=model, tokens=tokens, retry_after=retry_after, - window_spent=current_spent, + window_spent=window_spent, + exceeded_counter=exceeded, ) + super()._record_spend(cost, model, tokens) def __enter__(self) -> TemporalBudget: diff --git a/shekel/backends/__init__.py b/shekel/backends/__init__.py new file mode 100644 index 0000000..f657a22 --- /dev/null +++ b/shekel/backends/__init__.py @@ -0,0 +1,6 @@ +"""Distributed budget backends for Shekel. + +Optional backends that require extra dependencies: + + pip install shekel[redis] # RedisBackend, AsyncRedisBackend +""" diff --git a/shekel/backends/redis.py b/shekel/backends/redis.py new file mode 100644 index 0000000..7d29cc5 --- /dev/null +++ b/shekel/backends/redis.py @@ -0,0 +1,479 @@ +"""Redis-backed TemporalBudgetBackend for distributed budget enforcement. + +Requires the 'redis' optional dependency:: + + pip install shekel[redis] + +Usage:: + + from shekel import budget + from shekel.backends.redis import RedisBackend + + backend = RedisBackend() # reads REDIS_URL from env + backend = RedisBackend(url="redis://...") # explicit URL + + with budget("$5/hr", name="api", backend=backend): + run_agent() + +Features: +- Atomic all-or-nothing enforcement via Lua script (one round-trip). +- Lazy connection on first use; connection pool for reuse. +- Per-counter independent rolling windows. +- Fail-closed (default) or fail-open on backend unavailability. 
+- Circuit breaker: stops calling Redis after N consecutive errors. +- BudgetConfigMismatchError when a budget name is already registered + with different limits/windows. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import time +from typing import Any + +from shekel.exceptions import BudgetConfigMismatchError, BudgetExceededError + +# --------------------------------------------------------------------------- +# Lua script for atomic check-and-add. +# +# Key layout (one Redis hash per budget): +# shekel:tb:{name} +# spec_hash → "" (hex digest of config; mismatch detection) +# {counter}:max → "5.0" (limit, or "" for uncapped) +# {counter}:window_s → "3600" +# {counter}:start → "1234567890000" (ms, from Redis TIME) +# {counter}:spent → "2.34" +# +# Return values from Lua: +# {1, ""} allowed +# {0, counter} exceeded: counter name +# {-2, "spec_mismatch"} config mismatch detected +# --------------------------------------------------------------------------- +_LUA_SCRIPT = """ +local key = KEYS[1] +local spec_hash = ARGV[1] +local n = tonumber(ARGV[2]) -- number of counters +local MISMATCH_SENTINEL = -2 + +-- Check spec hash (mismatch detection). +local stored_hash = redis.call('HGET', key, 'spec_hash') +if stored_hash and stored_hash ~= '' and stored_hash ~= spec_hash then + return {MISMATCH_SENTINEL, 'spec_mismatch'} +end + +-- Fetch Redis server time (milliseconds). +local t = redis.call('TIME') +local now_ms = tonumber(t[1]) * 1000 + math.floor(tonumber(t[2]) / 1000) + +-- Counters are passed as triplets: name, amount, limit (''/nil=uncapped), window_s +-- starting at ARGV[3]. +local offset = 3 + +-- Phase 1: check all limits. +for i = 1, n do + local counter = ARGV[offset] + local amount = tonumber(ARGV[offset + 1]) + local limit_str = ARGV[offset + 2] + local window_ms = tonumber(ARGV[offset + 3]) * 1000 + offset = offset + 4 + + -- Window reset? + local start_ms = tonumber(redis.call('HGET', key, counter .. 
':start') or '0') or 0 + local spent = tonumber(redis.call('HGET', key, counter .. ':spent') or '0') or 0 + if start_ms > 0 and (now_ms - start_ms) >= window_ms then + spent = 0 + end + + -- Limit check ('' or missing = uncapped). + if limit_str ~= '' then + local limit = tonumber(limit_str) + if limit and spent + amount > limit then + return {0, counter} + end + end +end + +-- Phase 2: commit all counters. +offset = 3 +local max_window_ms = 0 +for i = 1, n do + local counter = ARGV[offset] + local amount = tonumber(ARGV[offset + 1]) + local limit_str = ARGV[offset + 2] + local window_ms = tonumber(ARGV[offset + 3]) * 1000 + offset = offset + 4 + + if window_ms > max_window_ms then max_window_ms = window_ms end + + local start_ms = tonumber(redis.call('HGET', key, counter .. ':start') or '0') or 0 + local spent = tonumber(redis.call('HGET', key, counter .. ':spent') or '0') or 0 + if start_ms == 0 or (now_ms - start_ms) >= window_ms then + -- fresh window + redis.call('HSET', key, counter .. ':start', now_ms) + redis.call('HSET', key, counter .. ':spent', amount) + else + redis.call('HINCRBYFLOAT', key, counter .. ':spent', amount) + end + redis.call('HSET', key, counter .. ':max', limit_str) + redis.call('HSET', key, counter .. ':window_s', tonumber(ARGV[offset - 1])) +end + +-- Store spec hash and set TTL = 2x max window. 
+redis.call('HSET', key, 'spec_hash', spec_hash) +redis.call('PEXPIRE', key, max_window_ms * 2) + +return {1, ''} +""" + + +def _build_spec_hash( + limits: dict[str, float | None], + windows: dict[str, float], +) -> str: + """Stable hex hash of {counter: (limit, window_s)} for mismatch detection.""" + payload = {k: (limits.get(k), windows.get(k)) for k in sorted(limits)} + return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()[:16] + + +def _build_argv( + spec_hash: str, + amounts: dict[str, float], + limits: dict[str, float | None], + windows: dict[str, float], +) -> list[str]: + """Build the ARGV list for the Lua script.""" + argv: list[str] = [spec_hash, str(len(amounts))] + for counter, amount in amounts.items(): + limit = limits.get(counter) + argv += [ + counter, + str(amount), + str(limit) if limit is not None else "", + str(windows[counter]), + ] + return argv + + +def _emit_unavailable(budget_name: str, error: Exception) -> None: + """Emit on_backend_unavailable event to all registered adapters.""" + try: + from shekel.integrations import AdapterRegistry + + AdapterRegistry.emit_event( + "on_backend_unavailable", + {"budget_name": budget_name, "error": str(error)}, + ) + except Exception: # noqa: BLE001 + pass + + +class RedisBackend: + """Synchronous Redis-backed rolling-window budget backend. + + Args: + url: Redis URL (e.g. ``redis://user:pass@host:6379/0``). + If omitted, reads ``REDIS_URL`` from the environment. + tls: Force TLS (sets ``ssl=True`` on the Redis connection). + on_unavailable: ``"closed"`` (default) — raise BudgetExceededError + when Redis is unreachable. ``"open"`` — allow the call through. + circuit_breaker_threshold: Consecutive errors before opening the + circuit breaker. Default 3. + circuit_breaker_cooldown: Seconds to wait before retrying after + circuit opens. Default 10. 
+ """ + + def __init__( + self, + url: str | None = None, + tls: bool = False, + on_unavailable: str = "closed", + circuit_breaker_threshold: int = 3, + circuit_breaker_cooldown: float = 10.0, + ) -> None: + self._url = url or os.environ.get("REDIS_URL", "redis://127.0.0.1:6379/0") + self._tls = tls + self._on_unavailable = on_unavailable + self._cb_threshold = circuit_breaker_threshold + self._cb_cooldown = circuit_breaker_cooldown + + self._client: Any = None # lazily created + self._script_sha: str | None = None + self._consecutive_errors: int = 0 + self._circuit_open_at: float | None = None + + # ------------------------------------------------------------------ + # Lazy connection + # ------------------------------------------------------------------ + + def _ensure_client(self) -> Any: + if self._client is None: + try: + import redis as redis_lib # noqa: PLC0415 + + kwargs: dict[str, Any] = {"decode_responses": False} + if self._tls: + kwargs["ssl"] = True + self._client = redis_lib.Redis.from_url(self._url, **kwargs) + except ImportError as exc: # pragma: no cover — only reached without redis installed + raise ImportError( + "RedisBackend requires 'redis': pip install shekel[redis]" + ) from exc + return self._client + + def _ensure_script(self) -> str: + if self._script_sha is None: + client = self._ensure_client() + self._script_sha = client.script_load(_LUA_SCRIPT) + return self._script_sha + + # ------------------------------------------------------------------ + # Circuit breaker helpers + # ------------------------------------------------------------------ + + def _is_circuit_open(self) -> bool: + if self._circuit_open_at is None: + return False + if time.monotonic() - self._circuit_open_at >= self._cb_cooldown: + # Cooldown elapsed — close the circuit and let the next call try. 
+ self._circuit_open_at = None + self._consecutive_errors = 0 + return False + return True + + def _record_error(self) -> None: + self._consecutive_errors += 1 + if self._consecutive_errors >= self._cb_threshold: + self._circuit_open_at = time.monotonic() + + def _record_success(self) -> None: + self._consecutive_errors = 0 + self._circuit_open_at = None + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def check_and_add( + self, + budget_name: str, + amounts: dict[str, float], + limits: dict[str, float | None], + windows: dict[str, float], + ) -> tuple[bool, str | None]: + key = f"shekel:tb:{budget_name}" + spec_hash = _build_spec_hash(limits, windows) + argv = _build_argv(spec_hash, amounts, limits, windows) + + if self._is_circuit_open(): + return self._handle_unavailable(budget_name, RuntimeError("Circuit breaker open")) + + try: + sha = self._ensure_script() + result = self._ensure_client().evalsha(sha, 1, key, *argv) + self._record_success() + except Exception as exc: + self._record_error() + _emit_unavailable(budget_name, exc) + return self._handle_unavailable(budget_name, exc) + + status = int(result[0]) + counter_bytes = result[1] + counter = counter_bytes.decode() if isinstance(counter_bytes, bytes) else str(counter_bytes) + + if status == -2: + raise BudgetConfigMismatchError( + f"Budget {budget_name!r} already registered with different limits/windows. " + "Call backend.reset(budget_name) to clear the existing state." 
+ ) + if status == 0: + return False, counter or None + return True, None + + def _handle_unavailable(self, budget_name: str, exc: Exception) -> tuple[bool, str | None]: + if self._on_unavailable == "open": + return True, None + raise BudgetExceededError( + spent=0.0, + limit=0.0, + model="unknown", + exceeded_counter="backend_unavailable", + ) from exc + + def get_state(self, budget_name: str) -> dict[str, float]: + key = f"shekel:tb:{budget_name}" + try: + raw = self._ensure_client().hgetall(key) + except ( + Exception + ): # noqa: BLE001 — get_state is best-effort; Redis unavailable → empty state + return {} + result: dict[str, float] = {} + for field_bytes, val_bytes in raw.items(): + field = field_bytes.decode() if isinstance(field_bytes, bytes) else str(field_bytes) + if field.endswith(":spent"): + counter = field[: -len(":spent")] + try: + result[counter] = float(val_bytes) + except (ValueError, TypeError): + pass # corrupt/non-numeric value — skip + return result + + def reset(self, budget_name: str) -> None: + key = f"shekel:tb:{budget_name}" + self._ensure_client().delete(key) + + def close(self) -> None: + if self._client is not None: + self._client.close() + + +class AsyncRedisBackend: + """Async Redis-backed rolling-window budget backend. + + Same semantics as :class:`RedisBackend` but all public methods are + coroutines — suitable for FastAPI, LangGraph, and other async contexts. + + Args: + url: Redis URL. If omitted, reads ``REDIS_URL`` from the environment. + tls: Force TLS. + on_unavailable: ``"closed"`` (default) or ``"open"``. + circuit_breaker_threshold: Consecutive errors before opening circuit. + circuit_breaker_cooldown: Seconds before retrying after circuit opens. 
+ """ + + def __init__( + self, + url: str | None = None, + tls: bool = False, + on_unavailable: str = "closed", + circuit_breaker_threshold: int = 3, + circuit_breaker_cooldown: float = 10.0, + ) -> None: + self._url = url or os.environ.get("REDIS_URL", "redis://127.0.0.1:6379/0") + self._tls = tls + self._on_unavailable = on_unavailable + self._cb_threshold = circuit_breaker_threshold + self._cb_cooldown = circuit_breaker_cooldown + + self._client: Any = None + self._script_sha: str | None = None + self._consecutive_errors: int = 0 + self._circuit_open_at: float | None = None + + async def _ensure_client(self) -> Any: + if self._client is None: + try: + import redis.asyncio as aioredis # noqa: PLC0415 + + kwargs: dict[str, Any] = {"decode_responses": False} + if self._tls: + kwargs["ssl"] = True + self._client = aioredis.Redis.from_url(self._url, **kwargs) + except ImportError as exc: # pragma: no cover — only reached without redis installed + raise ImportError( + "AsyncRedisBackend requires 'redis[asyncio]': pip install shekel[redis]" + ) from exc + return self._client + + async def _ensure_script(self) -> str: + if self._script_sha is None: + client = await self._ensure_client() + self._script_sha = await client.script_load(_LUA_SCRIPT) + return self._script_sha + + def _is_circuit_open(self) -> bool: + if self._circuit_open_at is None: + return False + if time.monotonic() - self._circuit_open_at >= self._cb_cooldown: + self._circuit_open_at = None + self._consecutive_errors = 0 + return False + return True + + def _record_error(self) -> None: + self._consecutive_errors += 1 + if self._consecutive_errors >= self._cb_threshold: + self._circuit_open_at = time.monotonic() + + def _record_success(self) -> None: + self._consecutive_errors = 0 + self._circuit_open_at = None + + async def check_and_add( + self, + budget_name: str, + amounts: dict[str, float], + limits: dict[str, float | None], + windows: dict[str, float], + ) -> tuple[bool, str | None]: + key = 
f"shekel:tb:{budget_name}" + spec_hash = _build_spec_hash(limits, windows) + argv = _build_argv(spec_hash, amounts, limits, windows) + + if self._is_circuit_open(): + return await self._handle_unavailable(budget_name, RuntimeError("Circuit breaker open")) + + try: + sha = await self._ensure_script() + client = await self._ensure_client() + result = await client.evalsha(sha, 1, key, *argv) + self._record_success() + except Exception as exc: + self._record_error() + _emit_unavailable(budget_name, exc) + return await self._handle_unavailable(budget_name, exc) + + status = int(result[0]) + counter_bytes = result[1] + counter = counter_bytes.decode() if isinstance(counter_bytes, bytes) else str(counter_bytes) + + if status == -2: + raise BudgetConfigMismatchError( + f"Budget {budget_name!r} already registered with different limits/windows." + ) + if status == 0: + return False, counter or None + return True, None + + async def _handle_unavailable( + self, budget_name: str, exc: Exception + ) -> tuple[bool, str | None]: + if self._on_unavailable == "open": + return True, None + raise BudgetExceededError( + spent=0.0, + limit=0.0, + model="unknown", + exceeded_counter="backend_unavailable", + ) from exc + + async def get_state(self, budget_name: str) -> dict[str, float]: + key = f"shekel:tb:{budget_name}" + try: + client = await self._ensure_client() + raw = await client.hgetall(key) + except ( + Exception + ): # noqa: BLE001 — get_state is best-effort; Redis unavailable → empty state + return {} + result: dict[str, float] = {} + for field_bytes, val_bytes in raw.items(): + field = field_bytes.decode() if isinstance(field_bytes, bytes) else str(field_bytes) + if field.endswith(":spent"): + counter = field[: -len(":spent")] + try: + result[counter] = float(val_bytes) + except (ValueError, TypeError): + pass # corrupt/non-numeric value — skip + return result + + async def reset(self, budget_name: str) -> None: + key = f"shekel:tb:{budget_name}" + client = await 
self._ensure_client() + await client.delete(key) + + async def close(self) -> None: + if self._client is not None: + await self._client.aclose() diff --git a/shekel/exceptions.py b/shekel/exceptions.py index 18f89b2..8b665f8 100644 --- a/shekel/exceptions.py +++ b/shekel/exceptions.py @@ -1,6 +1,15 @@ from __future__ import annotations +class BudgetConfigMismatchError(Exception): + """Raised when a distributed budget backend detects a config mismatch. + + Occurs when a budget name is already registered in the backend with + different limits or window settings than the current configuration. + Call reset() on the backend to clear the existing state. + """ + + class ToolBudgetExceededError(Exception): """Raised when tool invocations exceed the configured budget limit. @@ -51,6 +60,7 @@ def __init__( tokens: dict[str, int] | None = None, retry_after: float | None = None, window_spent: float | None = None, + exceeded_counter: str | None = None, ) -> None: self.spent = spent self.limit = limit @@ -58,6 +68,7 @@ def __init__( self.tokens: dict[str, int] = tokens if tokens is not None else {"input": 0, "output": 0} self.retry_after: float | None = retry_after self.window_spent: float | None = window_spent + self.exceeded_counter: str | None = exceeded_counter super().__init__(str(self)) def __str__(self) -> str: @@ -71,9 +82,113 @@ def __str__(self) -> str: ) else: last_call = f" Last call: {self.model}\n" + counter_note = f" Counter: {self.exceeded_counter}\n" if self.exceeded_counter else "" return ( f"Budget of ${self.limit:.2f} exceeded (${self.spent:.4f} spent)\n" f"{last_call}" + f"{counter_note}" f" Tip: Increase max_usd, add warn_at=0.8 for an early warning, " f"or add fallback='gpt-4o-mini' to switch to a cheaper model instead of raising." ) + + +class NodeBudgetExceededError(BudgetExceededError): + """Raised when a LangGraph node exceeds its budget cap. 
+ + Raised *before* the node body executes when an explicit cap is set, + or during execution when the parent budget is exhausted. + """ + + def __init__(self, node_name: str, spent: float, limit: float) -> None: + self.node_name = node_name + super().__init__(spent=spent, limit=limit, model=f"node:{node_name}") + + def __str__(self) -> str: + return ( + f"Node budget exceeded for '{self.node_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f})\n" + f" Tip: Increase b.node('{self.node_name}', max_usd=...) " + f"or remove the explicit cap to use the parent budget only.\n" + f" Run b.tree() for full spend breakdown." + ) + + +class AgentBudgetExceededError(BudgetExceededError): + """Raised when an agent exceeds its budget cap (CrewAI, OpenClaw).""" + + def __init__(self, agent_name: str, spent: float, limit: float) -> None: + self.agent_name = agent_name + super().__init__(spent=spent, limit=limit, model=f"agent:{agent_name}") + + def __str__(self) -> str: + return ( + f"Agent budget exceeded for '{self.agent_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f})\n" + f" Tip: Increase b.agent('{self.agent_name}', max_usd=...) " + f"or remove the explicit cap to use the parent budget only.\n" + f" Run b.tree() for full spend breakdown." + ) + + +class TaskBudgetExceededError(BudgetExceededError): + """Raised when a task exceeds its budget cap (CrewAI). + + Raised *before* the task executes when an explicit cap is set. + """ + + def __init__(self, task_name: str, spent: float, limit: float) -> None: + self.task_name = task_name + super().__init__(spent=spent, limit=limit, model=f"task:{task_name}") + + def __str__(self) -> str: + return ( + f"Task budget exceeded for '{self.task_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f})\n" + f" Tip: Increase b.task('{self.task_name}', max_usd=...) " + f"or remove the explicit cap to use the parent budget only.\n" + f" Run b.tree() for full spend breakdown." 
+ ) + + +class ChainBudgetExceededError(BudgetExceededError): + """Raised when a LangChain chain or runnable exceeds its budget cap. + + Raised *before* the chain body executes when an explicit cap is set, + or during execution when the parent budget is exhausted. + """ + + def __init__(self, chain_name: str, spent: float, limit: float) -> None: + self.chain_name = chain_name + super().__init__(spent=spent, limit=limit, model=f"chain:{chain_name}") + + def __str__(self) -> str: + return ( + f"Chain budget exceeded for '{self.chain_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f})\n" + f" Tip: Increase b.chain('{self.chain_name}', max_usd=...) " + f"or remove the explicit cap to use the parent budget only.\n" + f" Run b.tree() for full spend breakdown." + ) + + +class SessionBudgetExceededError(BudgetExceededError): + """Raised when an always-on agent session exceeds its rolling-window budget (OpenClaw).""" + + def __init__( + self, + agent_name: str, + spent: float, + limit: float, + window: float | None = None, + ) -> None: + self.agent_name = agent_name + self.window = window + super().__init__(spent=spent, limit=limit, model=f"session:{agent_name}") + + def __str__(self) -> str: + window_str = f" over {self.window:.0f}s window" if self.window is not None else "" + return ( + f"Session budget exceeded for agent '{self.agent_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f}{window_str})\n" + f" Tip: Increase the session budget or use a longer rolling window." + ) diff --git a/shekel/integrations/base.py b/shekel/integrations/base.py index 6a46bd1..0cc5f5b 100644 --- a/shekel/integrations/base.py +++ b/shekel/integrations/base.py @@ -137,6 +137,19 @@ def on_tool_budget_exceeded(self, error_data: dict[str, Any]) -> None: """ pass + def on_backend_unavailable(self, error_data: dict[str, Any]) -> None: + """Called when a distributed budget backend is unreachable or errors. 
+ + Fired before raising BudgetExceededError (fail-closed) or allowing + the call through (fail-open), depending on on_unavailable setting. + + Args: + error_data: Dictionary containing: + - budget_name: str - Name of the budget whose backend failed + - error: str - String description of the error + """ + pass # pragma: no cover — base stub; overridden by concrete adapters + def on_tool_warn(self, warn_data: dict[str, Any]) -> None: """Called when tool calls reach the warn_at threshold. diff --git a/shekel/providers/crewai.py b/shekel/providers/crewai.py index e1948e1..70c9901 100644 --- a/shekel/providers/crewai.py +++ b/shekel/providers/crewai.py @@ -1,7 +1,8 @@ -"""CrewAI provider adapter for Shekel tool budget tracking.""" +"""CrewAI provider adapter for Shekel tool budget tracking and agent/task circuit breaking.""" from __future__ import annotations +import warnings from typing import Any _original_run: Any = None @@ -74,3 +75,168 @@ def remove_patches(self) -> None: except (ImportError, AttributeError, TypeError): # pragma: no cover _original_run = None # pragma: no cover _original_arun = None # pragma: no cover + + +# --------------------------------------------------------------------------- +# CrewAI execution-level adapter — agent/task circuit breaking (v0.3.1) +# --------------------------------------------------------------------------- + +_execution_patch_refcount: int = 0 +_original_execute_task: Any = None + + +def _find_agent_cap(agent_name: str, active: Any) -> Any: + """Walk the budget parent chain to find a registered agent cap.""" + b: Any = active + while b is not None: + cb = b._agent_budgets.get(agent_name) + if cb is not None: + return cb + b = b.parent + return None + + +def _find_task_cap(task_name: str, active: Any) -> Any: + """Walk the budget parent chain to find a registered task cap.""" + b: Any = active + while b is not None: + cb = b._task_budgets.get(task_name) + if cb is not None: + return cb + b = b.parent + return None + + +def 
_has_any_task_caps(active: Any) -> bool: + """Return True if any budget in the parent chain has registered task caps.""" + b: Any = active + while b is not None: + if b._task_budgets: + return True + b = b.parent + return False + + +def _get_task_name(task: Any) -> str: + """Resolve task name: task.name (non-empty) → task.description (non-empty) → ''.""" + return getattr(task, "name", None) or getattr(task, "description", None) or "" + + +def _gate_execution(agent_name: str, task_name: str, task: Any, active: Any) -> None: + """Pre-execution gate: task cap → agent cap → global budget.""" + from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError + + # Warn when task.name is absent/empty and task caps are registered (silent-miss risk) + task_has_name = bool(getattr(task, "name", None)) + if not task_has_name and _has_any_task_caps(active): + desc = getattr(task, "description", "") or "" + if desc: + msg = ( + f"shekel: task has no name (description: '{desc[:50]}...') " + "— set task.name to apply caps." + ) + else: + msg = "shekel: task has no name and no description — set task.name to apply caps." 
+ warnings.warn(msg, UserWarning, stacklevel=4) + + # Task cap — most specific, checked first + task_cb = _find_task_cap(task_name, active) + if task_cb is not None and task_cb._spent >= task_cb.max_usd: + raise TaskBudgetExceededError( + task_name=task_name, spent=task_cb._spent, limit=task_cb.max_usd + ) + + # Agent cap + agent_cb = _find_agent_cap(agent_name, active) + if agent_cb is not None and agent_cb._spent >= agent_cb.max_usd: + raise AgentBudgetExceededError( + agent_name=agent_name, spent=agent_cb._spent, limit=agent_cb.max_usd + ) + + # Global budget check (mirrors langgraph._gate pattern) + if active._effective_limit is not None and active._spent >= active._effective_limit: + raise AgentBudgetExceededError( + agent_name=agent_name, spent=active._spent, limit=active._effective_limit + ) + + +def _attribute_execution_spend( + agent_name: str, task_name: str, active: Any, spend_before: float +) -> None: + """Attribute spend delta to both agent and task ComponentBudgets.""" + delta = active._spent - spend_before + if delta <= 0: + return + agent_cb = _find_agent_cap(agent_name, active) + if agent_cb is not None: + agent_cb._spent += delta + task_cb = _find_task_cap(task_name, active) + if task_cb is not None: + task_cb._spent += delta + + +class CrewAIExecutionAdapter: + """Patches ``Agent.execute_task`` for agent/task-level budget circuit breaking. + + Activated transparently by ``ShekelRuntime.probe()`` on ``budget().__enter__()``. + A reference counter ensures nested budgets don't double-patch or prematurely + restore the original method. + """ + + def install_patches(self, budget: Any) -> None: # noqa: ARG002 + """Patch ``Agent.execute_task``. Raises ``ImportError`` when crewai + is not installed so that ``ShekelRuntime.probe()`` silently skips it. 
+ """ + global _execution_patch_refcount, _original_execute_task + + import crewai.agent # raises ImportError if crewai not installed # noqa: F401 + from crewai.agent import Agent + + _execution_patch_refcount += 1 + if _execution_patch_refcount > 1: + return # already patched — just increment the refcount + + orig = Agent.execute_task + _original_execute_task = orig + + def _patched_execute_task( + self: Any, task: Any, context: Any = None, tools: Any = None + ) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + if active is None: + return orig(self, task, context, tools) + + agent_name = getattr(self, "role", str(self)) + task_name = _get_task_name(task) + + _gate_execution(agent_name, task_name, task, active) + spend_before = active._spent + result = orig(self, task, context, tools) + _attribute_execution_spend(agent_name, task_name, active, spend_before) + return result + + Agent.execute_task = _patched_execute_task + + def remove_patches(self, budget: Any) -> None: # noqa: ARG002 + """Restore ``Agent.execute_task``. Only restores when the last + active budget closes (reference count reaches zero). 
+ """ + global _execution_patch_refcount, _original_execute_task + + if _execution_patch_refcount <= 0: + return + _execution_patch_refcount -= 1 + if _execution_patch_refcount > 0: + return # other budgets still active + + if _original_execute_task is None: + return # pragma: no cover — defensive null check + try: + from crewai.agent import Agent + + Agent.execute_task = _original_execute_task + except ImportError: # pragma: no cover — defensive cleanup + pass + _original_execute_task = None # reset after restore (langchain.py pattern) diff --git a/shekel/providers/langchain.py b/shekel/providers/langchain.py index 17147b8..787992b 100644 --- a/shekel/providers/langchain.py +++ b/shekel/providers/langchain.py @@ -1,12 +1,22 @@ -"""LangChain / LangGraph provider adapter for Shekel tool budget tracking.""" +"""LangChain provider adapter for Shekel — tool budget tracking and chain-level circuit breaking.""" from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from shekel._budget import Budget _original_invoke: Any = None _original_ainvoke: Any = None +# Chain-level patch state (LangChainRunnerAdapter) +_chain_patch_refcount: int = 0 +_original_call_with_config: Any = None +_original_acall_with_config: Any = None +_original_sequence_invoke: Any = None +_original_sequence_ainvoke: Any = None + def _get_price(budget: Any, tool_name: str) -> float: if budget.tool_prices is not None and tool_name in budget.tool_prices: @@ -74,3 +84,190 @@ def remove_patches(self) -> None: except (ImportError, AttributeError, TypeError): # pragma: no cover _original_invoke = None # pragma: no cover _original_ainvoke = None # pragma: no cover + + +# --------------------------------------------------------------------------- +# Chain-level helpers +# --------------------------------------------------------------------------- + + +def _find_chain_cap(chain_name: str, active: Budget) -> Any: + """Walk the budget parent chain 
to find a registered chain cap. + + Returns the first ``ComponentBudget`` whose ``_chain_budgets`` contains + ``chain_name``, starting from ``active`` and walking toward the root. + Returns ``None`` if no ancestor has a cap for this chain. + """ + b: Any = active + while b is not None: + cb = b._chain_budgets.get(chain_name) + if cb is not None: + return cb + b = b.parent + return None + + +def _gate_chain(chain_name: str | None, active: Budget) -> None: + """Pre-execution budget check for a named chain or runnable. + + Raises ``ChainBudgetExceededError`` if the explicit chain cap or the + parent budget is already at / over its limit. No-op when ``chain_name`` + is falsy or when no cap is registered for it. + """ + if not chain_name: # pragma: no cover — callers guard on truthiness + return + + from shekel.exceptions import ChainBudgetExceededError + + cb = _find_chain_cap(chain_name, active) + if cb is not None and cb._spent >= cb.max_usd: + raise ChainBudgetExceededError(chain_name=chain_name, spent=cb._spent, limit=cb.max_usd) + + if active._effective_limit is not None and active._spent >= active._effective_limit: + raise ChainBudgetExceededError( + chain_name=chain_name, + spent=active._spent, + limit=active._effective_limit, + ) + + +def _attribute_chain_spend(chain_name: str | None, active: Budget, spend_before: float) -> None: + """Post-execution: add the spend delta to the chain's ComponentBudget._spent.""" + if not chain_name: # pragma: no cover — callers guard on truthiness + return + cb = _find_chain_cap(chain_name, active) + if cb is not None: + delta = active._spent - spend_before + if delta > 0: + cb._spent += delta + + +# --------------------------------------------------------------------------- +# LangChainRunnerAdapter — chain-level circuit breaking +# --------------------------------------------------------------------------- + + +class LangChainRunnerAdapter: + """Patches ``Runnable._call_with_config``, ``_acall_with_config``, and + 
``RunnableSequence.invoke/ainvoke`` for chain-level budget enforcement. + + Raises ``ImportError`` when ``langchain_core`` is not installed so that + ``ShekelRuntime.probe()`` silently skips it. + """ + + def install_patches(self, budget: Budget) -> None: # noqa: ARG002 + global _chain_patch_refcount, _original_call_with_config, _original_acall_with_config + global _original_sequence_invoke, _original_sequence_ainvoke + + import langchain_core.runnables.base # raises ImportError if not installed # noqa: F401 + from langchain_core.runnables.base import Runnable, RunnableSequence + + _chain_patch_refcount += 1 + if _chain_patch_refcount > 1: + return + + # -- Patch 1: Runnable._call_with_config (sync, covers RunnableLambda etc.) -- + orig_cwc = Runnable._call_with_config + _original_call_with_config = orig_cwc + + def _patched_cwc(self: Any, func: Any, input_: Any, config: Any, **kwargs: Any) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + name: str | None = getattr(self, "name", None) + if active is not None and name: + _gate_chain(name, active) + spend_before = active._spent + result = orig_cwc(self, func, input_, config, **kwargs) + _attribute_chain_spend(name, active, spend_before) + return result + return orig_cwc(self, func, input_, config, **kwargs) + + Runnable._call_with_config = _patched_cwc # type: ignore[method-assign,assignment] + + # -- Patch 2: Runnable._acall_with_config (async, covers async RunnableLambda) -- + orig_acwc = Runnable._acall_with_config + _original_acall_with_config = orig_acwc + + async def _patched_acwc( + self: Any, func: Any, input_: Any, config: Any, **kwargs: Any + ) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + name = getattr(self, "name", None) + if active is not None and name: + _gate_chain(name, active) + spend_before = active._spent + result = await orig_acwc(self, func, input_, config, **kwargs) + _attribute_chain_spend(name, active, 
spend_before) + return result + return await orig_acwc(self, func, input_, config, **kwargs) + + Runnable._acall_with_config = _patched_acwc # type: ignore[method-assign,assignment] + + # -- Patch 3: RunnableSequence.invoke (LCEL pipelines, sync) -- + orig_seq_invoke = RunnableSequence.invoke + _original_sequence_invoke = orig_seq_invoke + + def _patched_seq_invoke(self: Any, input: Any, config: Any = None, **kwargs: Any) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + name = getattr(self, "name", None) + if active is not None and name: + _gate_chain(name, active) + spend_before = active._spent + result = orig_seq_invoke(self, input, config, **kwargs) + _attribute_chain_spend(name, active, spend_before) + return result + return orig_seq_invoke(self, input, config, **kwargs) + + RunnableSequence.invoke = _patched_seq_invoke # type: ignore[method-assign] + + # -- Patch 4: RunnableSequence.ainvoke (LCEL pipelines, async) -- + orig_seq_ainvoke = RunnableSequence.ainvoke + _original_sequence_ainvoke = orig_seq_ainvoke + + async def _patched_seq_ainvoke( + self: Any, input: Any, config: Any = None, **kwargs: Any + ) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + name = getattr(self, "name", None) + if active is not None and name: + _gate_chain(name, active) + spend_before = active._spent + result = await orig_seq_ainvoke(self, input, config, **kwargs) + _attribute_chain_spend(name, active, spend_before) + return result + return await orig_seq_ainvoke(self, input, config, **kwargs) + + RunnableSequence.ainvoke = _patched_seq_ainvoke # type: ignore[method-assign] + + def remove_patches(self, budget: Budget) -> None: # noqa: ARG002 + global _chain_patch_refcount, _original_call_with_config, _original_acall_with_config + global _original_sequence_invoke, _original_sequence_ainvoke + + if _chain_patch_refcount <= 0: + return + _chain_patch_refcount -= 1 + if _chain_patch_refcount > 0: + 
return + + if _original_call_with_config is None: # pragma: no cover — defensive null check + return + try: + from langchain_core.runnables.base import Runnable, RunnableSequence + + Runnable._call_with_config = _original_call_with_config # type: ignore[method-assign] + Runnable._acall_with_config = _original_acall_with_config # type: ignore[method-assign] + RunnableSequence.invoke = _original_sequence_invoke # type: ignore[method-assign] + RunnableSequence.ainvoke = _original_sequence_ainvoke # type: ignore[method-assign] + except ImportError: # pragma: no cover — defensive cleanup + pass + _original_call_with_config = None + _original_acall_with_config = None + _original_sequence_invoke = None + _original_sequence_ainvoke = None diff --git a/shekel/providers/langgraph.py b/shekel/providers/langgraph.py new file mode 100644 index 0000000..ef2cde0 --- /dev/null +++ b/shekel/providers/langgraph.py @@ -0,0 +1,182 @@ +"""LangGraph adapter for Shekel — node-level circuit breaking (v0.3.1). + +Patches ``StateGraph.add_node()`` transparently so every node — sync and async +— is wrapped with a pre-execution budget gate. Requires no user code changes: +just open a ``budget()`` context before building the graph. + +How it works: + +1. On ``budget.__enter__()``, ``ShekelRuntime.probe()`` calls + ``LangGraphAdapter().install_patches(budget)``. The adapter patches + ``StateGraph.add_node`` with a version that wraps each node function with + ``_make_gate()``. +2. When a node runs, the gate: + a. Checks the explicit node cap (``b.node("name", max_usd=X)``), if set. + b. Checks the parent budget total. + c. On success, records the spend delta into ``ComponentBudget._spent``. +3. On ``budget.__exit__()``, ``ShekelRuntime.release()`` calls + ``remove_patches()``. A reference counter ensures nested budgets don't + double-patch or prematurely restore the original method. 
+""" + +from __future__ import annotations + +import functools +import inspect +from typing import Any + +_original_add_node: Any = None +_patch_refcount: int = 0 + + +class LangGraphAdapter: + """Patches ``StateGraph.add_node()`` for node-level budget enforcement.""" + + def install_patches(self, budget: Any) -> None: # noqa: ARG002 + """Patch ``StateGraph.add_node``. Raises ``ImportError`` when langgraph + is not installed so that ``ShekelRuntime.probe()`` silently skips it. + """ + global _original_add_node, _patch_refcount + + import langgraph.graph.state # raises ImportError if not installed # noqa: F401 + from langgraph.graph.state import StateGraph + + _patch_refcount += 1 + if _patch_refcount > 1: + return # already patched — just increment the refcount + + orig = StateGraph.add_node + _original_add_node = orig + + def _patched_add_node(self: Any, node: Any, action: Any = None, **kwargs: Any) -> Any: + if action is None and callable(node): + # add_node(fn) — fn.__name__ is the node name + node_name = getattr(node, "__name__", str(node)) + return orig(self, _make_gate(node, node_name), None, **kwargs) + if isinstance(node, str) and action is not None and callable(action): + # add_node("name", fn) + return orig(self, node, _make_gate(action, node), **kwargs) + # Passthrough: non-callable action or other invalid inputs — + # forward to LangGraph to let it raise its own validation error. + return orig(self, node, action, **kwargs) # pragma: no cover + + StateGraph.add_node = _patched_add_node # type: ignore[method-assign] + + def remove_patches(self, budget: Any) -> None: # noqa: ARG002 + """Restore ``StateGraph.add_node``. Only restores when the last + active budget closes (reference count reaches zero). 
+ """ + global _patch_refcount + + if _patch_refcount <= 0: + return + _patch_refcount -= 1 + if _patch_refcount > 0: + return # other budgets still active + + if _original_add_node is None: + return + try: + from langgraph.graph.state import StateGraph + + StateGraph.add_node = _original_add_node # type: ignore[method-assign] + except ImportError: # pragma: no cover — defensive cleanup + pass + + +# --------------------------------------------------------------------------- +# Gate factory +# --------------------------------------------------------------------------- + + +def _make_gate(fn: Any, node_name: str) -> Any: + """Return a sync or async wrapper that gates ``fn`` with budget checks.""" + if inspect.iscoroutinefunction(fn): + return _make_async_gate(fn, node_name) + return _make_sync_gate(fn, node_name) + + +def _make_sync_gate(fn: Any, node_name: str) -> Any: + @functools.wraps(fn) + def _gated(state: Any, *args: Any, **kwargs: Any) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + if active is None: + return fn(state, *args, **kwargs) + + _gate(node_name, active) + spend_before = active._spent + result = fn(state, *args, **kwargs) + _attribute_spend(node_name, active, spend_before) + return result + + return _gated + + +def _make_async_gate(fn: Any, node_name: str) -> Any: + @functools.wraps(fn) + async def _gated(state: Any, *args: Any, **kwargs: Any) -> Any: + from shekel._context import get_active_budget + + active = get_active_budget() + if active is None: + return await fn(state, *args, **kwargs) + + _gate(node_name, active) + spend_before = active._spent + result = await fn(state, *args, **kwargs) + _attribute_spend(node_name, active, spend_before) + return result + + return _gated + + +# --------------------------------------------------------------------------- +# Gate logic and spend attribution +# --------------------------------------------------------------------------- + + +def 
_find_node_cap(node_name: str, active: Any) -> Any: + """Walk the budget parent chain to find a registered node cap. + + Returns the first ``ComponentBudget`` whose ``_node_budgets`` contains + ``node_name``, starting from ``active`` and walking toward the root. + Returns ``None`` if no ancestor has a cap for this node. + """ + b: Any = active + while b is not None: + cb = b._node_budgets.get(node_name) + if cb is not None: + return cb + b = b.parent + return None + + +def _gate(node_name: str, active: Any) -> None: + """Pre-execution budget check. Raises ``NodeBudgetExceededError`` if the + explicit node cap or parent budget is already at / over its limit. + """ + from shekel.exceptions import NodeBudgetExceededError + + # 1. Explicit node cap check — walk up to find the cap (highest priority) + cb = _find_node_cap(node_name, active) + if cb is not None and cb._spent >= cb.max_usd: + raise NodeBudgetExceededError(node_name=node_name, spent=cb._spent, limit=cb.max_usd) + + # 2. Parent budget total check + if active._effective_limit is not None and active._spent >= active._effective_limit: + raise NodeBudgetExceededError( + node_name=node_name, + spent=active._spent, + limit=active._effective_limit, + ) + + +def _attribute_spend(node_name: str, active: Any, spend_before: float) -> None: + """Post-execution: add the spend delta to the node's ComponentBudget._spent.""" + cb = _find_node_cap(node_name, active) + if cb is not None: + delta = active._spent - spend_before + if delta > 0: + cb._spent += delta diff --git a/tests/integrations/test_redis_docker.py b/tests/integrations/test_redis_docker.py new file mode 100644 index 0000000..7d3993b --- /dev/null +++ b/tests/integrations/test_redis_docker.py @@ -0,0 +1,419 @@ +"""Docker-based Redis integration tests for RedisBackend and AsyncRedisBackend. + +Spins up a real ``redis:alpine`` container to verify atomic enforcement, +window expiry, config mismatch detection, and the full ``budget()`` + +``RedisBackend`` flow. 
+ +Requires: Docker daemon running, ``redis`` and ``testcontainers`` packages. +Auto-skipped when either is absent or Docker is unavailable. +""" + +from __future__ import annotations + +import time + +import pytest + +testcontainers = pytest.importorskip("testcontainers", reason="testcontainers not installed") +pytest.importorskip("redis", reason="redis not installed") + +from testcontainers.redis import RedisContainer # noqa: E402 + +from shekel.backends.redis import AsyncRedisBackend, RedisBackend # noqa: E402 +from shekel.exceptions import BudgetConfigMismatchError, BudgetExceededError # noqa: E402 + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_BUDGET = "test_budget" + + +@pytest.fixture(scope="session") +def redis_url() -> str: # type: ignore[return] + """Start a redis:alpine container for the test session; yield its URL.""" + with RedisContainer(image="redis:alpine") as container: + host = container.get_container_host_ip() + port = container.get_exposed_port(container.port) + yield f"redis://{host}:{port}/0" + + +@pytest.fixture() +def backend(redis_url: str) -> RedisBackend: # type: ignore[return] + """Fresh RedisBackend pointing at the Docker container.""" + b = RedisBackend(url=redis_url) + b.reset(_BUDGET) # clean slate before each test + yield b # type: ignore[misc] + b.reset(_BUDGET) + b.close() + + +@pytest.fixture() +async def async_backend(redis_url: str) -> AsyncRedisBackend: # type: ignore[return] + """Fresh AsyncRedisBackend — pytest-asyncio manages the event loop.""" + b = AsyncRedisBackend(url=redis_url) + await b.reset(_BUDGET) # clean slate before each test + yield b # type: ignore[misc] + await b.reset(_BUDGET) + await b.close() + + +# --------------------------------------------------------------------------- +# Group A — Basic sync check-and-add +# 
--------------------------------------------------------------------------- + + +def test_sync_allowed_within_limit(backend: RedisBackend) -> None: + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 2.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +def test_sync_rejected_when_limit_exceeded(backend: RedisBackend) -> None: + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + +def test_sync_cumulative_spend_triggers_rejection(backend: RedisBackend) -> None: + backend.check_and_add(_BUDGET, {"usd": 4.0}, {"usd": 5.0}, {"usd": 3600.0}) + allowed, exceeded = backend.check_and_add(_BUDGET, {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) + assert allowed is False + assert exceeded == "usd" + + +def test_sync_none_limit_never_rejected(backend: RedisBackend) -> None: + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 9999.0}, + limits={"usd": None}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +# --------------------------------------------------------------------------- +# Group B — All-or-nothing atomicity +# --------------------------------------------------------------------------- + + +def test_sync_all_or_nothing_second_counter_fails(backend: RedisBackend) -> None: + """When llm_calls exceeds limit, usd must NOT be committed.""" + # Pre-fill with both counters so spec_hash is consistent across calls. + backend.check_and_add( + _BUDGET, + amounts={"usd": 1.0, "llm_calls": 1.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + + # usd would be ok (1+1=2≤5) but llm_calls would exceed (1+101=102>100). 
+ allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 1.0, "llm_calls": 101.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + assert allowed is False + assert exceeded == "llm_calls" + + # usd must still be 1.0 — all-or-nothing means nothing was committed. + state = backend.get_state(_BUDGET) + assert state["usd"] == pytest.approx(1.0) + assert state["llm_calls"] == pytest.approx(1.0) + + +# --------------------------------------------------------------------------- +# Group C — Multi-cap +# --------------------------------------------------------------------------- + + +def test_sync_multi_cap_both_within_limit(backend: RedisBackend) -> None: + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 0.5, "llm_calls": 1.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +def test_sync_multi_cap_calls_counter_rejected(backend: RedisBackend) -> None: + backend.check_and_add( + _BUDGET, + amounts={"usd": 0.01, "llm_calls": 99.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 0.01, "llm_calls": 2.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + assert allowed is False + assert exceeded == "llm_calls" + + +# --------------------------------------------------------------------------- +# Group D — Window expiry (real clock, short window) +# --------------------------------------------------------------------------- + + +def test_sync_window_resets_after_expiry(backend: RedisBackend) -> None: + """Spend inside a 1 s window; after 1.1 s the counter resets.""" + backend.check_and_add(_BUDGET, {"usd": 4.5}, {"usd": 5.0}, {"usd": 1.0}) + + time.sleep(1.5) + + # New window: 4.5 should not carry over. 
+ allowed, exceeded = backend.check_and_add(_BUDGET, {"usd": 4.5}, {"usd": 5.0}, {"usd": 1.0}) + assert allowed is True + assert exceeded is None + + +def test_sync_per_counter_independent_windows(backend: RedisBackend) -> None: + """usd window = 10 s, llm_calls window = 1 s. Only calls expire.""" + backend.check_and_add( + _BUDGET, + amounts={"usd": 4.0, "llm_calls": 90.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 10.0, "llm_calls": 1.0}, + ) + + time.sleep(1.5) + + # usd still accumulates (4.0 + 0.5 = 4.5 ≤ 5.0); calls start fresh (90 ≤ 100). + allowed, exceeded = backend.check_and_add( + _BUDGET, + amounts={"usd": 0.5, "llm_calls": 90.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 10.0, "llm_calls": 1.0}, + ) + assert allowed is True + assert exceeded is None + + state = backend.get_state(_BUDGET) + assert state["usd"] == pytest.approx(4.5) + assert state["llm_calls"] == pytest.approx(90.0) + + +# --------------------------------------------------------------------------- +# Group E — State & reset +# --------------------------------------------------------------------------- + + +def test_sync_get_state_reflects_spend(backend: RedisBackend) -> None: + backend.check_and_add( + _BUDGET, + amounts={"usd": 1.23, "llm_calls": 7.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + state = backend.get_state(_BUDGET) + assert state["usd"] == pytest.approx(1.23) + assert state["llm_calls"] == pytest.approx(7.0) + + +def test_sync_reset_clears_all_counters(backend: RedisBackend) -> None: + backend.check_and_add(_BUDGET, {"usd": 3.0}, {"usd": 5.0}, {"usd": 3600.0}) + backend.reset(_BUDGET) + state = backend.get_state(_BUDGET) + assert state == {} + + +def test_sync_get_state_empty_for_unknown_budget(backend: RedisBackend) -> None: + state = backend.get_state("no_such_budget_xyz") + assert state == {} + + +# --------------------------------------------------------------------------- +# 
Group F — Config mismatch +# --------------------------------------------------------------------------- + + +def test_config_mismatch_raises_error(backend: RedisBackend) -> None: + """Re-using a budget name with different limits raises BudgetConfigMismatchError.""" + backend.check_and_add(_BUDGET, {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + with pytest.raises(BudgetConfigMismatchError): + backend.check_and_add( + _BUDGET, + amounts={"usd": 0.01}, + limits={"usd": 10.0}, # different limit + windows={"usd": 3600.0}, + ) + + +# --------------------------------------------------------------------------- +# Group G — Multi-instance isolation +# --------------------------------------------------------------------------- + + +def test_two_budget_names_are_isolated(redis_url: str) -> None: + b = RedisBackend(url=redis_url) + try: + b.reset("budget_a") + b.reset("budget_b") + + b.check_and_add("budget_a", {"usd": 4.9}, {"usd": 5.0}, {"usd": 3600.0}) + + # budget_b is fresh — should be allowed. + allowed, exceeded = b.check_and_add("budget_b", {"usd": 4.9}, {"usd": 5.0}, {"usd": 3600.0}) + assert allowed is True + + state_a = b.get_state("budget_a") + state_b = b.get_state("budget_b") + assert state_a["usd"] == pytest.approx(4.9) + assert state_b["usd"] == pytest.approx(4.9) + finally: + b.reset("budget_a") + b.reset("budget_b") + b.close() + + +def test_shared_name_two_backend_instances(redis_url: str) -> None: + """Two independent RedisBackend objects share state when using the same budget name.""" + b1 = RedisBackend(url=redis_url) + b2 = RedisBackend(url=redis_url) + shared = "shared_budget" + try: + b1.reset(shared) + b1.check_and_add(shared, {"usd": 3.0}, {"usd": 5.0}, {"usd": 3600.0}) + + # b2 reads the same Redis key — sees the existing 3.0 spend. 
+ allowed, exceeded = b2.check_and_add(shared, {"usd": 3.0}, {"usd": 5.0}, {"usd": 3600.0}) + assert allowed is False + assert exceeded == "usd" + finally: + b1.reset(shared) + b1.close() + b2.close() + + +# --------------------------------------------------------------------------- +# Group H — Async check-and-add +# --------------------------------------------------------------------------- + + +async def test_async_allowed_within_limit(async_backend: AsyncRedisBackend) -> None: + allowed, exceeded = await async_backend.check_and_add( + _BUDGET, + amounts={"usd": 2.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +async def test_async_rejected_when_limit_exceeded(async_backend: AsyncRedisBackend) -> None: + allowed, exceeded = await async_backend.check_and_add( + _BUDGET, + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + +async def test_async_window_resets_after_expiry(async_backend: AsyncRedisBackend) -> None: + await async_backend.check_and_add(_BUDGET, {"usd": 4.5}, {"usd": 5.0}, {"usd": 1.0}) + time.sleep(1.5) + allowed, exceeded = await async_backend.check_and_add( + _BUDGET, {"usd": 4.5}, {"usd": 5.0}, {"usd": 1.0} + ) + assert allowed is True + assert exceeded is None + + +async def test_async_get_state(async_backend: AsyncRedisBackend) -> None: + await async_backend.check_and_add( + _BUDGET, + amounts={"usd": 2.5, "llm_calls": 10.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + state = await async_backend.get_state(_BUDGET) + assert state["usd"] == pytest.approx(2.5) + assert state["llm_calls"] == pytest.approx(10.0) + + +async def test_async_reset(async_backend: AsyncRedisBackend) -> None: + await async_backend.check_and_add(_BUDGET, {"usd": 1.0}, {"usd": 5.0}, {"usd": 3600.0}) + await async_backend.reset(_BUDGET) + state = await 
async_backend.get_state(_BUDGET) + assert state == {} + + +# --------------------------------------------------------------------------- +# Group I — Full budget() factory integration +# --------------------------------------------------------------------------- + + +def test_budget_factory_enforces_usd_cap_with_redis(redis_url: str) -> None: + from shekel import budget + + b = budget("$0.001/hr", name="docker_usd", backend=RedisBackend(url=redis_url)) + b._backend.reset("docker_usd") + try: + with pytest.raises(BudgetExceededError): + b._record_spend(0.002, "gpt-4o-mini", {"input": 10, "output": 5}) + finally: + b._backend.reset("docker_usd") + b._backend.close() # type: ignore[union-attr] + + +def test_budget_factory_multi_cap_with_redis(redis_url: str) -> None: + from shekel import budget + + b = budget( + "$5/hr + 1 call/hr", + name="docker_multicap", + backend=RedisBackend(url=redis_url), + ) + b._backend.reset("docker_multicap") + try: + # First call allowed. + b._record_spend(0.001, "gpt-4o-mini", {"input": 10, "output": 5}) + + # Second call: llm_calls limit (1) exceeded. + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.001, "gpt-4o-mini", {"input": 10, "output": 5}) + + assert exc_info.value.exceeded_counter == "llm_calls" + finally: + b._backend.reset("docker_multicap") + b._backend.close() # type: ignore[union-attr] + + +def test_budget_factory_window_reset_allows_new_spend(redis_url: str) -> None: + """After a 1 s window expires, spend is allowed again.""" + from shekel import budget + + b = budget("$0.001/s", name="docker_window", backend=RedisBackend(url=redis_url)) + b._backend.reset("docker_window") + try: + with pytest.raises(BudgetExceededError): + b._record_spend(0.002, "gpt-4o-mini", {}) + + time.sleep(1.5) + + # Fresh window — should succeed. 
+ b._record_spend(0.0005, "gpt-4o-mini", {}) + finally: + b._backend.reset("docker_window") + b._backend.close() # type: ignore[union-attr] diff --git a/tests/integrations/test_temporal_integration.py b/tests/integrations/test_temporal_integration.py index 5519b7f..8ef13c8 100644 --- a/tests/integrations/test_temporal_integration.py +++ b/tests/integrations/test_temporal_integration.py @@ -74,9 +74,8 @@ def test_cost_recorded_within_window(self) -> None: with tb: record(input_tokens=1000, output_tokens=500, model="gpt-4o-mini") - spent, window_start = backend.get_state("recording") - assert spent > 0 - assert window_start is not None + state = backend.get_state("recording") + assert state.get("usd", 0.0) > 0 def test_spend_accumulates_across_entries(self) -> None: """Window spend accumulates across multiple context entries.""" @@ -86,12 +85,12 @@ def test_spend_accumulates_across_entries(self) -> None: with tb: record(input_tokens=1000, output_tokens=500, model="gpt-4o-mini") - spent_after_first, _ = backend.get_state("accum") + spent_after_first = backend.get_state("accum").get("usd", 0.0) with tb: record(input_tokens=1000, output_tokens=500, model="gpt-4o-mini") - spent_after_second, _ = backend.get_state("accum") + spent_after_second = backend.get_state("accum").get("usd", 0.0) assert spent_after_second > spent_after_first def test_budget_spent_property_matches_cost(self) -> None: @@ -120,7 +119,7 @@ def test_error_has_retry_after(self) -> None: # Pre-fill window so retry_after is meaningful t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("retry", 0.0009, 0.001, 3600.0) + backend.check_and_add("retry", {"usd": 0.0009}, {"usd": 0.001}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 100.0): with pytest.raises(BudgetExceededError) as exc_info: @@ -137,7 +136,7 @@ def test_error_has_window_spent(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("ws_err", 0.0009, 
0.001, 3600.0) + backend.check_and_add("ws_err", {"usd": 0.0009}, {"usd": 0.001}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 100.0): with pytest.raises(BudgetExceededError) as exc_info: @@ -172,7 +171,7 @@ def test_window_resets_after_expiry(self) -> None: # Fill the window to the brim with patch("time.monotonic", return_value=t0): - backend.check_and_add("reset", 0.0009, 0.001, 3600.0) + backend.check_and_add("reset", {"usd": 0.0009}, {"usd": 0.001}, {"usd": 3600.0}) # After window expires, the next entry starts fresh with patch("shekel._temporal.time.monotonic", return_value=t0 + 3601.0): @@ -191,7 +190,7 @@ def test_window_reset_emits_adapter_event(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("evt_reset", 2.0, 5.0, 3600.0) + backend.check_and_add("evt_reset", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 3601.0): with tb: @@ -220,13 +219,19 @@ def test_window_resets_allow_spend_again(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - assert backend.check_and_add("reuse", 0.001, 0.001, 3600.0) is True + allowed, _ = backend.check_and_add( + "reuse", {"usd": 0.001}, {"usd": 0.001}, {"usd": 3600.0} + ) + assert allowed is True # New window — same amount should be accepted again with patch("time.monotonic", return_value=t0 + 3601.0): - assert backend.check_and_add("reuse", 0.001, 0.001, 3600.0) is True - spent, _ = backend.get_state("reuse") - assert spent == pytest.approx(0.001) + allowed2, _ = backend.check_and_add( + "reuse", {"usd": 0.001}, {"usd": 0.001}, {"usd": 3600.0} + ) + assert allowed2 is True + state = backend.get_state("reuse") + assert state["usd"] == pytest.approx(0.001) # --------------------------------------------------------------------------- @@ -311,7 +316,7 @@ def test_multiple_adapters_all_receive_window_reset(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - 
backend.check_and_add("multi_adapt", 2.0, 5.0, 3600.0) + backend.check_and_add("multi_adapt", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 3601.0): with tb: @@ -448,7 +453,9 @@ async def _run(): ) t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("async_retry", 0.0009, 0.001, 3600.0) + backend.check_and_add( + "async_retry", {"usd": 0.0009}, {"usd": 0.001}, {"usd": 3600.0} + ) with patch("time.monotonic", return_value=t0 + 100.0): async with tb: @@ -471,7 +478,7 @@ async def _run(): ) t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("async_reset", 2.0, 5.0, 3600.0) + backend.check_and_add("async_reset", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 3601.0): async with tb: @@ -534,8 +541,8 @@ def make_budget(user_id: str) -> TemporalBudget: with make_budget("bob"): record(input_tokens=500, output_tokens=250, model="gpt-4o-mini") - alice_spent, _ = shared_backend.get_state("user:alice") - bob_spent, _ = shared_backend.get_state("user:bob") + alice_spent = shared_backend.get_state("user:alice").get("usd", 0.0) + bob_spent = shared_backend.get_state("user:bob").get("usd", 0.0) assert alice_spent > 0 assert bob_spent > 0 @@ -573,17 +580,24 @@ def test_window_reset_for_one_user_does_not_affect_others(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - shared_backend.check_and_add("user:alice", 0.0008, 0.001, 3600.0) - shared_backend.check_and_add("user:bob", 0.0005, 0.001, 3600.0) + shared_backend.check_and_add( + "user:alice", {"usd": 0.0008}, {"usd": 0.001}, {"usd": 3600.0} + ) + shared_backend.check_and_add( + "user:bob", {"usd": 0.0005}, {"usd": 0.001}, {"usd": 3600.0} + ) # Alice's window expires; bob's is still active with patch("time.monotonic", return_value=t0 + 3601.0): # Alice's window reset: check_and_add should start fresh for alice - result = 
shared_backend.check_and_add("user:alice", 0.0008, 0.001, 3600.0) + result, _ = shared_backend.check_and_add( + "user:alice", {"usd": 0.0008}, {"usd": 0.001}, {"usd": 3600.0} + ) assert result is True # accepted in fresh window # Bob's window should still have his original spend - bob_spent, bob_start = shared_backend.get_state("user:bob") + bob_state = shared_backend.get_window_info("user:bob") + bob_spent, bob_start = bob_state.get("usd", (0.0, None)) assert bob_spent == pytest.approx(0.0005) assert bob_start == pytest.approx(t0) @@ -680,7 +694,7 @@ def test_otel_adapter_receives_window_reset_event(self) -> None: t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("otel_reset", 2.0, 5.0, 3600.0) + backend.check_and_add("otel_reset", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 3601.0): with tb: diff --git a/tests/test_crewai_wrappers.py b/tests/test_crewai_wrappers.py new file mode 100644 index 0000000..0edec06 --- /dev/null +++ b/tests/test_crewai_wrappers.py @@ -0,0 +1,534 @@ +"""Tests for CrewAI agent/task-level budget enforcement (v0.3.1). + +Domain: CrewAIExecutionAdapter — patching, agent/task gate, spend attribution, +nested budget inheritance, silent-miss warnings. 
+""" + +from __future__ import annotations + +import sys +import types +import warnings +from typing import Any + +import pytest + +import shekel.providers.crewai as crewai_mod +from shekel import budget +from shekel._budget import Budget +from shekel._runtime import ShekelRuntime +from shekel.exceptions import AgentBudgetExceededError, TaskBudgetExceededError +from shekel.providers.crewai import CrewAIExecutionAdapter + +# --------------------------------------------------------------------------- +# Helpers — fake crewai.agent module injection +# --------------------------------------------------------------------------- + + +def _make_crewai_modules(simulated_cost: float = 0.0) -> tuple[types.ModuleType, type]: + """Return (fake_crewai_agent_mod, Agent class) injected into sys.modules.""" + fake_crewai = types.ModuleType("crewai") + fake_agent_mod = types.ModuleType("crewai.agent") + + class Agent: + def __init__(self, role: str = "Senior Researcher") -> None: + self.role = role + + def execute_task(self, task: Any, context: Any = None, tools: Any = None) -> str: + if simulated_cost > 0: + from shekel._context import get_active_budget + + active = get_active_budget() + if active is not None: + active._spent += simulated_cost + active._spent_direct += simulated_cost + return "done" + + fake_agent_mod.Agent = Agent # type: ignore[attr-defined] + sys.modules["crewai"] = fake_crewai # type: ignore[assignment] + sys.modules["crewai.agent"] = fake_agent_mod # type: ignore[assignment] + return fake_agent_mod, Agent + + +def _cleanup_crewai_modules() -> None: + for key in ["crewai", "crewai.agent"]: + sys.modules.pop(key, None) + + +class MockTask: + def __init__(self, name: str = "research", description: str = "Do research about AI") -> None: + self.name = name + self.description = description + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + 
+@pytest.fixture(autouse=True) +def restore_adapter_state(): + """Restore CrewAIExecutionAdapter patch state and ShekelRuntime registry.""" + original_refcount = crewai_mod._execution_patch_refcount + original_execute_task = crewai_mod._original_execute_task + original_registry = ShekelRuntime._adapter_registry[:] + + yield + + # Restore module state + crewai_mod._execution_patch_refcount = original_refcount + crewai_mod._original_execute_task = original_execute_task + + # Restore registry + ShekelRuntime._adapter_registry = original_registry + + # Clean up injected fake modules + _cleanup_crewai_modules() + + +# --------------------------------------------------------------------------- +# Group 0: Smoke-check — Agent.execute_task signature +# --------------------------------------------------------------------------- + + +def test_execute_task_signature_has_task_param() -> None: + """Fake Agent.execute_task has the expected 'task' parameter.""" + import inspect + + _, Agent = _make_crewai_modules() + sig = inspect.signature(Agent.execute_task) + assert "task" in sig.parameters + + +# --------------------------------------------------------------------------- +# Group 1: CrewAIExecutionAdapter registered in ShekelRuntime +# --------------------------------------------------------------------------- + + +def test_crewai_execution_adapter_in_runtime_registry() -> None: + """CrewAIExecutionAdapter is registered in ShekelRuntime at import time.""" + assert CrewAIExecutionAdapter in ShekelRuntime._adapter_registry + + +# --------------------------------------------------------------------------- +# Group 2: install_patches / remove_patches lifecycle +# --------------------------------------------------------------------------- + + +def test_install_patches_raises_import_error_when_crewai_absent() -> None: + """install_patches() raises ImportError when crewai.agent is not importable.""" + _cleanup_crewai_modules() + sys.modules["crewai"] = None # type: 
ignore[assignment] + sys.modules["crewai.agent"] = None # type: ignore[assignment] + adapter = CrewAIExecutionAdapter() + with pytest.raises(ImportError): + adapter.install_patches(Budget(max_usd=5.00)) + + +def test_install_patches_patches_execute_task() -> None: + """install_patches() replaces Agent.execute_task with patched version.""" + _, Agent = _make_crewai_modules() + original = Agent.execute_task + adapter = CrewAIExecutionAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + assert Agent.execute_task is not original + adapter.remove_patches(b) + + +def test_remove_patches_restores_original_execute_task() -> None: + """remove_patches() restores Agent.execute_task to the original.""" + _, Agent = _make_crewai_modules() + original = Agent.execute_task + adapter = CrewAIExecutionAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + adapter.remove_patches(b) + assert Agent.execute_task is original + + +def test_remove_patches_noop_when_not_installed() -> None: + """remove_patches() is idempotent when called without a prior install.""" + _make_crewai_modules() + adapter = CrewAIExecutionAdapter() + b = Budget(max_usd=5.00) + # Should not raise + adapter.remove_patches(b) + adapter.remove_patches(b) + + +def test_refcount_nested_budgets_do_not_double_patch() -> None: + """Second install_patches() call increments refcount but does not re-patch.""" + _, Agent = _make_crewai_modules() + adapter = CrewAIExecutionAdapter() + b1 = Budget(max_usd=5.00) + b2 = Budget(max_usd=3.00) + adapter.install_patches(b1) + patched = Agent.execute_task + adapter.install_patches(b2) + assert Agent.execute_task is patched # not re-patched + assert crewai_mod._execution_patch_refcount == 2 + adapter.remove_patches(b2) + adapter.remove_patches(b1) + + +def test_refcount_patch_not_removed_until_last_budget_exits() -> None: + """Patch is retained until the outermost budget exits.""" + _, Agent = _make_crewai_modules() + adapter = CrewAIExecutionAdapter() + b1 
= Budget(max_usd=5.00) + b2 = Budget(max_usd=3.00) + adapter.install_patches(b1) + adapter.install_patches(b2) + patched = Agent.execute_task + adapter.remove_patches(b2) + assert Agent.execute_task is patched # still patched + adapter.remove_patches(b1) + + +# --------------------------------------------------------------------------- +# Group 3: Pre-execution gate — agent cap +# --------------------------------------------------------------------------- + + +def test_agent_cap_exceeded_raises_before_execute() -> None: + """AgentBudgetExceededError raised before execute_task body runs.""" + _, Agent = _make_crewai_modules(simulated_cost=0.05) + + with budget(max_usd=5.00) as b: + b.agent("Senior Researcher", max_usd=0.10) + b._agent_budgets["Senior Researcher"]._spent = 0.10 # exhaust cap + + agent = Agent(role="Senior Researcher") + task = MockTask() + with pytest.raises(AgentBudgetExceededError) as exc_info: + agent.execute_task(task) + + # No spend attributed — gate fires before body runs + assert b._spent == 0.0 + + assert exc_info.value.agent_name == "Senior Researcher" + + +def test_agent_cap_not_exceeded_allows_execute() -> None: + """No exception when agent cap has remaining budget.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.agent("Senior Researcher", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask() + result = agent.execute_task(task) + + assert result == "done" + assert b._agent_budgets["Senior Researcher"]._spent == 0.0 + + +def test_agent_cap_on_outer_budget_enforced_in_nested_inner_budget() -> None: + """Agent cap on outer budget is enforced when execution runs in inner context.""" + _make_crewai_modules() + with budget(max_usd=5.00, name="outer") as outer: + outer.agent("Senior Researcher", max_usd=0.10) + outer._agent_budgets["Senior Researcher"]._spent = 0.10 + + with budget(max_usd=2.00, name="inner"): + from crewai.agent import Agent + + agent = Agent(role="Senior 
Researcher") + task = MockTask() + with pytest.raises(AgentBudgetExceededError): + agent.execute_task(task) + + +def test_agent_cap_found_on_grandparent_budget() -> None: + """Agent cap registered on grandparent is enforced two levels deep.""" + _make_crewai_modules() + with budget(max_usd=10.00, name="root") as root: + root.agent("Senior Researcher", max_usd=0.10) + root._agent_budgets["Senior Researcher"]._spent = 0.10 + + with budget(max_usd=5.00, name="mid"): + with budget(max_usd=2.00, name="inner"): + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask() + with pytest.raises(AgentBudgetExceededError): + agent.execute_task(task) + + +def test_global_budget_exhausted_raises_agent_budget_exceeded_error() -> None: + """AgentBudgetExceededError raised when global budget is exhausted.""" + _make_crewai_modules() + with budget(max_usd=0.10) as b: + b._spent = 0.10 # exhaust global budget + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask() + with pytest.raises(AgentBudgetExceededError): + agent.execute_task(task) + + +# --------------------------------------------------------------------------- +# Group 4: Pre-execution gate — task cap +# --------------------------------------------------------------------------- + + +def test_task_cap_exceeded_raises_before_execute() -> None: + """TaskBudgetExceededError raised before execute_task body runs.""" + _make_crewai_modules(simulated_cost=0.05) + + with budget(max_usd=5.00) as b: + b.task("research", max_usd=0.10) + b._task_budgets["research"]._spent = 0.10 + + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + + with pytest.raises(TaskBudgetExceededError) as exc_info: + agent.execute_task(task) + + # No spend attributed — gate fires before body runs + assert b._spent == 0.0 + + assert exc_info.value.task_name == "research" + + +def test_task_cap_not_exceeded_allows_execute() -> 
None: + """No exception when task cap has remaining budget.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("research", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + result = agent.execute_task(task) + + assert result == "done" + + +def test_task_cap_takes_precedence_over_agent_cap() -> None: + """TaskBudgetExceededError fires when both task and agent caps are exceeded.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("research", max_usd=0.10) + b.agent("Senior Researcher", max_usd=0.10) + b._task_budgets["research"]._spent = 0.10 + b._agent_budgets["Senior Researcher"]._spent = 0.10 + + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + with pytest.raises(TaskBudgetExceededError): + agent.execute_task(task) + + +def test_task_name_falls_back_to_description() -> None: + """Task name falls back to description when name is None.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("Do research about AI", max_usd=0.10) + b._task_budgets["Do research about AI"]._spent = 0.10 + + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name=None, description="Do research about AI") # type: ignore[arg-type] + with pytest.raises(TaskBudgetExceededError) as exc_info: + agent.execute_task(task) + + assert exc_info.value.task_name == "Do research about AI" + + +def test_task_name_empty_string_falls_back_to_description() -> None: + """Empty string task name is treated as absent, falls back to description.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("Do research about AI", max_usd=0.10) + b._task_budgets["Do research about AI"]._spent = 0.10 + + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="", description="Do research about AI") + with 
pytest.raises(TaskBudgetExceededError) as exc_info: + agent.execute_task(task) + + assert exc_info.value.task_name == "Do research about AI" + + +# --------------------------------------------------------------------------- +# Group 5: Spend attribution +# --------------------------------------------------------------------------- + + +def test_spend_attributed_to_agent_component_budget() -> None: + """Spend delta after execute_task is attributed to agent ComponentBudget.""" + _make_crewai_modules(simulated_cost=0.05) + with budget(max_usd=5.00) as b: + b.agent("Senior Researcher", max_usd=2.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask() + agent.execute_task(task) + + assert b._agent_budgets["Senior Researcher"]._spent == pytest.approx(0.05) + + +def test_spend_attributed_to_task_component_budget() -> None: + """Spend delta after execute_task is attributed to task ComponentBudget.""" + _make_crewai_modules(simulated_cost=0.05) + with budget(max_usd=5.00) as b: + b.task("research", max_usd=2.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + agent.execute_task(task) + + assert b._task_budgets["research"]._spent == pytest.approx(0.05) + + +def test_spend_attributed_to_both_agent_and_task() -> None: + """Same spend delta attributed to both agent and task ComponentBudgets.""" + _make_crewai_modules(simulated_cost=0.07) + with budget(max_usd=5.00) as b: + b.agent("Senior Researcher", max_usd=2.00) + b.task("research", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + agent.execute_task(task) + + assert b._agent_budgets["Senior Researcher"]._spent == pytest.approx(0.07) + assert b._task_budgets["research"]._spent == pytest.approx(0.07) + + +def test_zero_spend_task_does_not_update_component_budgets() -> None: + """Zero-cost execution does not change ComponentBudget._spent.""" 
+ _make_crewai_modules(simulated_cost=0.0) + with budget(max_usd=5.00) as b: + b.agent("Senior Researcher", max_usd=2.00) + b.task("research", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name="research") + agent.execute_task(task) + + assert b._agent_budgets["Senior Researcher"]._spent == 0.0 + assert b._task_budgets["research"]._spent == 0.0 + + +# --------------------------------------------------------------------------- +# Group 6: Silent-miss warnings +# --------------------------------------------------------------------------- + + +def test_unnamed_task_with_registered_cap_emits_warning() -> None: + """warnings.warn emitted when task has no name and task caps are registered.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("something", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name=None, description="Do research") # type: ignore[arg-type] + with pytest.warns(UserWarning, match="task has no name"): + agent.execute_task(task) + + +def test_unnamed_task_warning_includes_description() -> None: + """Warning message includes (truncated) task description.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("something", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name=None, description="A" * 100) # type: ignore[arg-type] + with pytest.warns(UserWarning, match="description:"): + agent.execute_task(task) + + +def test_unnamed_task_no_description_warning_message() -> None: + """Warning message uses alternate text when task has no name AND no description.""" + _make_crewai_modules() + with budget(max_usd=5.00) as b: + b.task("something", max_usd=1.00) + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name=None, description="") # type: ignore[arg-type] + with pytest.warns(UserWarning, 
match="no name and no description"): + agent.execute_task(task) + + +def test_unnamed_task_without_registered_cap_no_warning() -> None: + """No warning emitted when task has no name but no task caps are registered.""" + _make_crewai_modules() + with budget(max_usd=5.00): + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask(name=None, description="Do research") # type: ignore[arg-type] + with warnings.catch_warnings(): + warnings.simplefilter("error") + agent.execute_task(task) # must not raise WarningException + + +# --------------------------------------------------------------------------- +# Group 7: Passthrough — no active budget +# --------------------------------------------------------------------------- + + +def test_no_active_budget_execute_task_passthrough() -> None: + """execute_task runs normally when no budget() context is active.""" + _make_crewai_modules() + adapter = CrewAIExecutionAdapter() + adapter.install_patches(Budget(max_usd=5.00)) + + from crewai.agent import Agent + + agent = Agent(role="Senior Researcher") + task = MockTask() + # No budget context — should run through without error + result = agent.execute_task(task) + assert result == "done" + + +# --------------------------------------------------------------------------- +# Group 8: ShekelRuntime integration +# --------------------------------------------------------------------------- + + +def test_runtime_probe_installs_crewai_adapter_when_crewai_installed() -> None: + """ShekelRuntime.probe() activates CrewAIExecutionAdapter when crewai is present.""" + _, Agent = _make_crewai_modules() + original = Agent.execute_task + b = Budget(max_usd=5.00) + ShekelRuntime._adapter_registry = [CrewAIExecutionAdapter] + runtime = ShekelRuntime(b) + runtime.probe() + assert Agent.execute_task is not original + runtime.release() + + +def test_runtime_release_removes_crewai_adapter() -> None: + """ShekelRuntime.release() restores Agent.execute_task.""" + _, 
Agent = _make_crewai_modules() + original = Agent.execute_task + b = Budget(max_usd=5.00) + ShekelRuntime._adapter_registry = [CrewAIExecutionAdapter] + runtime = ShekelRuntime(b) + runtime.probe() + runtime.release() + assert Agent.execute_task is original diff --git a/tests/test_distributed_budgets.py b/tests/test_distributed_budgets.py new file mode 100644 index 0000000..8ba71dd --- /dev/null +++ b/tests/test_distributed_budgets.py @@ -0,0 +1,1246 @@ +"""Tests for distributed budget features. + +Domain: distributed budgets — multi-cap spec parsing, new backend protocol, +Redis backend, BudgetConfigMismatchError, on_backend_unavailable. +""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Group A — Multi-cap spec parsing (_parse_cap_spec) +# --------------------------------------------------------------------------- + + +def test_parse_cap_spec_single_usd(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("$5/hr") + assert caps == [("usd", 5.0, 3600.0)] + + +def test_parse_cap_spec_single_calls(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("100 calls/hr") + assert caps == [("llm_calls", 100.0, 3600.0)] + + +def test_parse_cap_spec_single_tools(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("20 tools/hr") + assert caps == [("tool_calls", 20.0, 3600.0)] + + +def test_parse_cap_spec_usd_keyword(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("5 usd/hr") + assert caps == [("usd", 5.0, 3600.0)] + + +def test_parse_cap_spec_multi_same_window(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("$5/hr + 100 calls/hr") + assert len(caps) == 2 + assert ("usd", 5.0, 3600.0) in caps + assert ("llm_calls", 100.0, 3600.0) in caps + + +def 
test_parse_cap_spec_multi_different_windows(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("$5/hr + 100 calls/30min") + assert len(caps) == 2 + assert ("usd", 5.0, 3600.0) in caps + assert ("llm_calls", 100.0, 1800.0) in caps + + +def test_parse_cap_spec_three_caps(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("$5/hr + 100 calls/hr + 20 tools/hr") + assert len(caps) == 3 + counters = [c[0] for c in caps] + assert "usd" in counters + assert "llm_calls" in counters + assert "tool_calls" in counters + + +def test_parse_cap_spec_with_window_count(): + from shekel._temporal import _parse_cap_spec + + caps = _parse_cap_spec("$10/30min") + assert caps == [("10", 10.0, 1800.0)] or caps == [("usd", 10.0, 1800.0)] + # Either form is acceptable, key is the window_s + counter, amount, window_s = caps[0] + assert amount == 10.0 + assert window_s == 1800.0 + + +def test_parse_cap_spec_rejects_calendar_unit(): + from shekel._temporal import _parse_cap_spec + + with pytest.raises(ValueError): + _parse_cap_spec("$5/day") + + +def test_parse_cap_spec_rejects_unknown_cap_type(): + from shekel._temporal import _parse_cap_spec + + with pytest.raises(ValueError): + _parse_cap_spec("100 widgets/hr") + + +def test_parse_cap_spec_rejects_garbage(): + from shekel._temporal import _parse_cap_spec + + with pytest.raises(ValueError): + _parse_cap_spec("hello world") + + +def test_parse_cap_spec_rejects_zero_amount(): + from shekel._temporal import _parse_cap_spec + + with pytest.raises(ValueError): + _parse_cap_spec("$0/hr") + + +def test_parse_cap_spec_singular_forms(): + """'call' and 'tool' (singular) should be accepted.""" + from shekel._temporal import _parse_cap_spec + + caps_call = _parse_cap_spec("1 call/hr") + assert caps_call[0][0] == "llm_calls" + + caps_tool = _parse_cap_spec("1 tool/hr") + assert caps_tool[0][0] == "tool_calls" + + +# --------------------------------------------------------------------------- +# Group B 
— budget() factory: form mixing and multi-cap +# --------------------------------------------------------------------------- + + +def test_budget_factory_rejects_spec_with_max_usd(): + from shekel import budget + + with pytest.raises(ValueError, match="[Mm]ix"): + budget("$5/hr", name="api", max_usd=10.0) + + +def test_budget_factory_rejects_spec_with_max_llm_calls(): + from shekel import budget + + with pytest.raises(ValueError, match="[Mm]ix"): + budget("$5/hr", name="api", max_llm_calls=100) + + +def test_budget_factory_rejects_spec_with_window_seconds(): + from shekel import budget + + with pytest.raises(ValueError, match="[Mm]ix"): + budget("$5/hr", name="api", window_seconds=3600) + + +def test_budget_factory_multi_cap_spec_string(): + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget("$5/hr + 100 calls/hr", name="api") + assert isinstance(b, TemporalBudget) + + +def test_budget_factory_multi_cap_spec_has_both_caps(): + from shekel import budget + + b = budget("$5/hr + 100 calls/hr", name="api") + assert "usd" in b._caps + assert "llm_calls" in b._caps + + +def test_budget_factory_multi_cap_kwargs(): + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget(max_usd=5.0, max_llm_calls=100, window_seconds=3600, name="api") + assert isinstance(b, TemporalBudget) + assert "usd" in b._caps + assert "llm_calls" in b._caps + + +def test_budget_factory_single_calls_spec(): + """budget('100 calls/hr', name='x') — no USD cap.""" + from shekel import budget + from shekel._temporal import TemporalBudget + + b = budget("100 calls/hr", name="x") + assert isinstance(b, TemporalBudget) + assert "llm_calls" in b._caps + + +# --------------------------------------------------------------------------- +# Group C — New InMemoryBackend multi-cap protocol +# --------------------------------------------------------------------------- + + +def test_new_inmemory_get_state_fresh(): + from shekel._temporal import 
InMemoryBackend + + backend = InMemoryBackend() + state = backend.get_state("new_key") + assert state == {} + + +def test_new_inmemory_check_and_add_within_limits(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 2.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +def test_new_inmemory_check_and_add_updates_state(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + backend.check_and_add("b1", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) + state = backend.get_state("b1") + assert state.get("usd") == pytest.approx(2.0) + + +def test_new_inmemory_check_and_add_usd_exceeded(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + +def test_new_inmemory_check_and_add_calls_exceeded(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + # Fill up to limit + backend.check_and_add("b1", {"llm_calls": 99.0}, {"llm_calls": 100.0}, {"llm_calls": 3600.0}) + # One more should exceed + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"llm_calls": 2.0}, + limits={"llm_calls": 100.0}, + windows={"llm_calls": 3600.0}, + ) + assert allowed is False + assert exceeded == "llm_calls" + + +def test_new_inmemory_check_usd_before_calls(): + """When both would exceed, usd is reported (checked first).""" + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 6.0, "llm_calls": 101.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + +def 
test_new_inmemory_none_limit_always_allowed(): + """None limit = no cap; counter is still tracked.""" + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 999.0}, + limits={"usd": None}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + assert backend.get_state("b1")["usd"] == pytest.approx(999.0) + + +def test_new_inmemory_all_or_nothing(): + """If second counter fails, first counter is NOT incremented.""" + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + # First add some usd + backend.check_and_add("b1", {"usd": 1.0}, {"usd": 5.0}, {"usd": 3600.0}) + + # Now try to add where llm_calls would exceed but usd would not + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 1.0, "llm_calls": 101.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 3600.0}, + ) + assert allowed is False + assert exceeded == "llm_calls" + # USD should NOT have been incremented (all-or-nothing) + state = backend.get_state("b1") + assert state["usd"] == pytest.approx(1.0) + + +def test_new_inmemory_reset_clears_state(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + backend.check_and_add("b1", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) + backend.reset("b1") + assert backend.get_state("b1") == {} + + +def test_new_inmemory_per_counter_window_expiry(): + """Each counter can have its own window; they expire independently.""" + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + t0 = 1000.0 + + with patch("time.monotonic", return_value=t0): + backend.check_and_add( + "b1", + amounts={"usd": 4.0, "llm_calls": 90.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 60.0}, # calls window is 1min + ) + + # After 2 minutes: llm_calls window expired, usd window still active + with 
patch("time.monotonic", return_value=t0 + 120.0): + allowed, exceeded = backend.check_and_add( + "b1", + amounts={"usd": 0.5, "llm_calls": 90.0}, + limits={"usd": 5.0, "llm_calls": 100.0}, + windows={"usd": 3600.0, "llm_calls": 60.0}, + ) + # usd: 4.0 + 0.5 = 4.5 <= 5.0 ✓; llm_calls: window expired → reset to 0, 90 <= 100 ✓ + assert allowed is True + state = backend.get_state("b1") + assert state["usd"] == pytest.approx(4.5) + assert state["llm_calls"] == pytest.approx(90.0) # fresh window + + +def test_new_inmemory_window_start_set_on_first_add(): + from shekel._temporal import InMemoryBackend + + backend = InMemoryBackend() + backend.check_and_add("b1", {"usd": 1.0}, {"usd": 5.0}, {"usd": 3600.0}) + # get_window_info is an optional observability method + if hasattr(backend, "get_window_info"): + info = backend.get_window_info("b1") + _, window_start = info["usd"] + assert window_start is not None + + +# --------------------------------------------------------------------------- +# Group D — TemporalBudget multi-cap behavior +# --------------------------------------------------------------------------- + + +def test_temporal_multicap_usd_only_still_works(): + """Backward-compat: single USD cap still works as before.""" + from shekel._temporal import TemporalBudget + from shekel.exceptions import BudgetExceededError + + tb = TemporalBudget(max_usd=5.0, window_seconds=3600, name="usd_only") + + with pytest.raises(BudgetExceededError): + tb._record_spend(6.0, "model", {"input": 100, "output": 100}) + + +def test_temporal_multicap_spec_raises_on_usd_exceed(): + from shekel import budget + from shekel.exceptions import BudgetExceededError + + b = budget("$5/hr + 100 calls/hr", name="api") + + with pytest.raises(BudgetExceededError): + b._record_spend(6.0, "model", {"input": 100, "output": 100}) + + +def test_temporal_multicap_spec_raises_on_calls_exceed(): + from shekel import budget + from shekel.exceptions import BudgetExceededError + + b = budget("$5/hr + 1 
call/hr", name="api") # 1 call limit + + # First call should succeed + # Second call should fail (call limit exceeded) + b._record_spend(0.001, "model", {"input": 10, "output": 10}) + + with pytest.raises(BudgetExceededError): + b._record_spend(0.001, "model", {"input": 10, "output": 10}) + + +def test_temporal_multicap_error_has_exceeded_counter(): + from shekel import budget + from shekel.exceptions import BudgetExceededError + + b = budget("$5/hr + 1 call/hr", name="api") + b._record_spend(0.001, "model", {}) # first call ok + + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.001, "model", {}) + + # The error should indicate which counter was exceeded + assert exc_info.value.exceeded_counter == "llm_calls" + + +def test_temporal_multicap_kwargs_form(): + from shekel import budget + from shekel.exceptions import BudgetExceededError + + b = budget(max_usd=5.0, max_llm_calls=2, window_seconds=3600, name="kwarg_api") + + b._record_spend(0.001, "model", {}) + b._record_spend(0.001, "model", {}) + + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.001, "model", {}) + + assert exc_info.value.exceeded_counter == "llm_calls" + + +def test_temporal_multicap_caps_dict_structure(): + """_caps stores {counter: (limit, window_s)}.""" + from shekel import budget + + b = budget("$5/hr + 100 calls/30min", name="test") + assert b._caps["usd"] == (5.0, 3600.0) + assert b._caps["llm_calls"] == (100.0, 1800.0) + + +def test_temporal_no_usd_cap_calls_only(): + from shekel import budget + from shekel.exceptions import BudgetExceededError + + b = budget("1 call/hr", name="calls_only") + + b._record_spend(0.0, "model", {}) # first call ok + + with pytest.raises(BudgetExceededError) as exc_info: + b._record_spend(0.0, "model", {}) + + assert exc_info.value.exceeded_counter == "llm_calls" + + +# --------------------------------------------------------------------------- +# Group E — BudgetConfigMismatchError +# 
--------------------------------------------------------------------------- + + +def test_budget_config_mismatch_error_exists(): + from shekel.exceptions import BudgetConfigMismatchError + + err = BudgetConfigMismatchError("Budget 'api' already registered with different limits") + assert "api" in str(err) + + +def test_budget_config_mismatch_error_is_exception(): + from shekel.exceptions import BudgetConfigMismatchError + + with pytest.raises(BudgetConfigMismatchError): + raise BudgetConfigMismatchError("mismatch") + + +def test_budget_config_mismatch_error_exported(): + """BudgetConfigMismatchError is importable from shekel.exceptions.""" + from shekel import exceptions + + assert hasattr(exceptions, "BudgetConfigMismatchError") + + +# --------------------------------------------------------------------------- +# Group F — on_backend_unavailable observability event +# --------------------------------------------------------------------------- + + +def test_on_backend_unavailable_noop_in_base(): + from shekel.integrations.base import ObservabilityAdapter + + adapter = ObservabilityAdapter() + # Should not raise + adapter.on_backend_unavailable({"budget_name": "api", "error": "timeout"}) + + +def test_on_backend_unavailable_receives_correct_fields(): + from shekel.integrations.base import ObservabilityAdapter + + received: list[dict[str, Any]] = [] + + class TestAdapter(ObservabilityAdapter): + def on_backend_unavailable(self, data: dict[str, Any]) -> None: + received.append(data) + + adapter = TestAdapter() + adapter.on_backend_unavailable({"budget_name": "api", "error": "conn refused"}) + + assert len(received) == 1 + assert received[0]["budget_name"] == "api" + assert "error" in received[0] + + +# --------------------------------------------------------------------------- +# Group G — RedisBackend unit tests (mocked redis) +# --------------------------------------------------------------------------- + + +def _make_redis_mock(lua_result: Any = None) -> MagicMock: 
+ """Create a mock Redis client that returns lua_result from evalsha/eval.""" + mock_client = MagicMock() + mock_client.script_load.return_value = "fakescriptsha" + if lua_result is None: + lua_result = [1, b""] # allowed=1, exceeded_counter="" + mock_client.evalsha.return_value = lua_result + mock_client.hgetall.return_value = {b"usd:spent": b"2.0"} + mock_client.delete.return_value = 1 + return mock_client + + +def test_redis_backend_importable(): + """RedisBackend is importable (redis is an optional dep — skip if not installed).""" + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + assert RedisBackend is not None + + +def test_redis_backend_check_and_add_allowed(): + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = _make_redis_mock(lua_result=[1, b""]) + backend = RedisBackend() + backend._client = mock_client # inject mock + backend._script_sha = "fakescriptsha" + + allowed, exceeded = backend.check_and_add( + "api", + amounts={"usd": 0.01}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + +def test_redis_backend_check_and_add_rejected(): + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = _make_redis_mock(lua_result=[0, b"usd"]) + backend = RedisBackend() + backend._client = mock_client + backend._script_sha = "fakescriptsha" + + allowed, exceeded = backend.check_and_add( + "api", + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + +def test_redis_backend_fail_closed_on_error(): + """When Redis is unreachable, default behavior raises BudgetExceededError.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetExceededError + + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + 
mock_client.evalsha.side_effect = redis_lib.RedisError("connection refused") + + backend = RedisBackend(on_unavailable="closed") + backend._client = mock_client + backend._script_sha = "sha" + + with pytest.raises(BudgetExceededError, match="[Uu]navailable|[Bb]ackend"): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + +def test_redis_backend_fail_open_when_configured(): + """With on_unavailable='open', Redis errors allow the call through.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.side_effect = redis_lib.RedisError("connection refused") + + backend = RedisBackend(on_unavailable="open") + backend._client = mock_client + backend._script_sha = "sha" + + allowed, exceeded = backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + assert allowed is True + assert exceeded is None + + +def test_redis_backend_emits_on_backend_unavailable_event(): + """When backend is unavailable, on_backend_unavailable event is emitted.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetExceededError + from shekel.integrations import AdapterRegistry + + AdapterRegistry.clear() + mock_adapter = MagicMock() + AdapterRegistry.register(mock_adapter) + + try: + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = RedisBackend(on_unavailable="closed") + backend._client = mock_client + backend._script_sha = "sha" + + with pytest.raises(BudgetExceededError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + mock_adapter.on_backend_unavailable.assert_called_once() + call_kwargs = mock_adapter.on_backend_unavailable.call_args[0][0] + assert 
call_kwargs["budget_name"] == "api" + finally: + AdapterRegistry.clear() + + +def test_redis_backend_circuit_breaker_stops_calling_redis(): + """After N consecutive errors, circuit breaker stops calling Redis.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetExceededError + + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = RedisBackend( + on_unavailable="closed", + circuit_breaker_threshold=3, + circuit_breaker_cooldown=10.0, + ) + backend._client = mock_client + backend._script_sha = "sha" + + # Trigger 3 consecutive errors to open circuit breaker + for _ in range(3): + with pytest.raises(BudgetExceededError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + evalsha_calls_before = mock_client.evalsha.call_count + + # Next call — circuit is open, should NOT call Redis + with pytest.raises(BudgetExceededError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + # evalsha should not have been called again + assert mock_client.evalsha.call_count == evalsha_calls_before + + +def test_redis_backend_circuit_breaker_resets_after_cooldown(): + """Circuit breaker allows retry after cooldown period.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetExceededError + + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = RedisBackend( + on_unavailable="closed", + circuit_breaker_threshold=3, + circuit_breaker_cooldown=10.0, + ) + backend._client = mock_client + backend._script_sha = "sha" + + t0 = 1000.0 + with patch("time.monotonic", return_value=t0): + for _ in range(3): + with 
pytest.raises(BudgetExceededError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + evalsha_before = mock_client.evalsha.call_count + + # After cooldown, circuit should close and retry Redis + with patch("time.monotonic", return_value=t0 + 11.0): + with pytest.raises(BudgetExceededError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + assert mock_client.evalsha.call_count > evalsha_before + + +def test_redis_backend_config_mismatch_raises(): + """Spec hash mismatch raises BudgetConfigMismatchError.""" + pytest.importorskip("redis") + + from shekel.backends.redis import RedisBackend + from shekel.exceptions import BudgetConfigMismatchError + + mock_client = MagicMock() + mock_client.script_load.return_value = "sha" + # Simulate mismatch: Lua returns -1 or special sentinel for mismatch + mock_client.evalsha.return_value = [-2, b"spec_mismatch"] + + backend = RedisBackend() + backend._client = mock_client + backend._script_sha = "sha" + + with pytest.raises(BudgetConfigMismatchError): + backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + +def test_redis_backend_get_state(): + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + mock_client.hgetall.return_value = { + b"usd:spent": b"2.34", + b"llm_calls:spent": b"45", + } + + backend = RedisBackend() + backend._client = mock_client + + state = backend.get_state("api") + assert state["usd"] == pytest.approx(2.34) + assert state["llm_calls"] == pytest.approx(45.0) + + +def test_redis_backend_reset_deletes_key(): + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + backend = RedisBackend() + backend._client = mock_client + + backend.reset("api") + mock_client.delete.assert_called_once() + + +def test_redis_backend_close(): + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + 
mock_client = MagicMock() + backend = RedisBackend() + backend._client = mock_client + + backend.close() + mock_client.close.assert_called_once() + + +# --------------------------------------------------------------------------- +# Group H — AsyncRedisBackend unit tests (mocked async redis) +# --------------------------------------------------------------------------- + + +def test_async_redis_backend_importable(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + assert AsyncRedisBackend is not None + + +def test_async_redis_backend_check_and_add_allowed(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.return_value = [1, b""] + + backend = AsyncRedisBackend() + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + allowed, exceeded = await backend.check_and_add( + "api", + amounts={"usd": 0.01}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + + asyncio.run(_run()) + + +def test_async_redis_backend_check_and_add_rejected(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.script_load.return_value = "sha" + mock_client.evalsha.return_value = [0, b"usd"] + + backend = AsyncRedisBackend() + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + allowed, exceeded = await backend.check_and_add( + "api", + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + + asyncio.run(_run()) + + +def test_async_redis_backend_fail_closed_on_error(): + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import AsyncRedisBackend + from shekel.exceptions import BudgetExceededError + + 
mock_client = AsyncMock() + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = AsyncRedisBackend(on_unavailable="closed") + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + with pytest.raises(BudgetExceededError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + asyncio.run(_run()) + + +def test_async_redis_backend_get_state(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.hgetall.return_value = {b"usd:spent": b"1.5"} + + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + state = await backend.get_state("api") + assert state["usd"] == pytest.approx(1.5) + + asyncio.run(_run()) + + +def test_async_redis_backend_reset(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + await backend.reset("api") + mock_client.delete.assert_called_once() + + asyncio.run(_run()) + + +def test_async_redis_backend_close(): + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + await backend.close() + mock_client.aclose.assert_called_once() + + asyncio.run(_run()) + + +# --------------------------------------------------------------------------- +# Group I — Coverage completeness: _temporal.py missing branches +# --------------------------------------------------------------------------- + + +def test_temporal_budget_no_window_seconds_raises(): + """kwargs form without window_seconds should raise ValueError.""" + from shekel._temporal import TemporalBudget + + with pytest.raises(ValueError, match="window_seconds"): + TemporalBudget(max_usd=5.0, 
name="no_ws") + + +def test_temporal_budget_tool_calls_cap_kwarg(): + """max_tool_calls kwarg builds 'tool_calls' cap.""" + from shekel._temporal import TemporalBudget + + tb = TemporalBudget(max_tool_calls=10, window_seconds=3600, name="tc") + assert "tool_calls" in tb._caps + + +def test_temporal_budget_no_caps_raises(): + """kwargs form with no cap at all should raise ValueError.""" + from shekel._temporal import TemporalBudget + + with pytest.raises(ValueError, match="cap"): + TemporalBudget(window_seconds=3600, name="no_cap") + + +def test_lazy_window_reset_skips_backend_without_get_window_info(): + """_lazy_window_reset returns early if backend lacks get_window_info.""" + from unittest.mock import MagicMock + + from shekel._temporal import TemporalBudget + + # spec=[] means the mock has NO attributes — hasattr returns False. + mock_backend = MagicMock(spec=[]) + tb = TemporalBudget(max_usd=5.0, window_seconds=3600, name="no_winfo", backend=mock_backend) + # __enter__ calls _lazy_window_reset — should not raise. + with tb: + pass + + +def test_lazy_window_reset_skips_when_window_start_none(): + """_lazy_window_reset returns early when window_start is None.""" + from shekel._temporal import InMemoryBackend, TemporalBudget + + backend = InMemoryBackend() + tb = TemporalBudget(max_usd=5.0, window_seconds=3600, name="wstart_none", backend=backend) + # Inject state with window_start=None (fresh counter not yet started). + backend._state["wstart_none"] = {"usd": (2.0, None)} + # __enter__ → _lazy_window_reset → window_start is None → early return. 
+ with tb: + pass + + +# --------------------------------------------------------------------------- +# Group J — Coverage completeness: redis.py missing branches (mocked) +# --------------------------------------------------------------------------- + + +def test_emit_unavailable_swallows_emit_exception(): + """_emit_unavailable must not propagate when AdapterRegistry.emit_event raises.""" + pytest.importorskip("redis") + from unittest.mock import patch + + from shekel.backends.redis import _emit_unavailable + + with patch( + "shekel.integrations.AdapterRegistry.emit_event", + side_effect=RuntimeError("registry exploded"), + ): + _emit_unavailable("api", Exception("test")) # must not raise + + +def test_redis_backend_tls_passes_ssl_kwarg(): + """RedisBackend(tls=True) passes ssl=True to Redis.from_url.""" + pytest.importorskip("redis") + from unittest.mock import patch + + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + + with patch.object(redis_lib.Redis, "from_url", return_value=MagicMock()) as mock_from_url: + backend = RedisBackend(tls=True) + backend._ensure_client() + _, kwargs = mock_from_url.call_args + assert kwargs.get("ssl") is True + + +def test_redis_backend_get_state_returns_empty_on_redis_error(): + """get_state returns {} when hgetall raises.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + mock_client.hgetall.side_effect = redis_lib.RedisError("timeout") + backend = RedisBackend() + backend._client = mock_client + + assert backend.get_state("api") == {} + + +def test_redis_backend_get_state_skips_non_numeric_value(): + """get_state silently skips fields that can't be parsed as float.""" + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + mock_client.hgetall.return_value = {b"usd:spent": b"not_a_number"} + backend = RedisBackend() + backend._client = mock_client + 
+ assert backend.get_state("api") == {} + + +def test_async_redis_backend_tls_passes_ssl_kwarg(): + """AsyncRedisBackend(tls=True) passes ssl=True to aioredis.Redis.from_url.""" + pytest.importorskip("redis") + from unittest.mock import patch + + import redis.asyncio as aioredis + + from shekel.backends.redis import AsyncRedisBackend + + async def _run() -> None: + with patch.object(aioredis.Redis, "from_url", return_value=AsyncMock()) as mock_from_url: + backend = AsyncRedisBackend(tls=True) + await backend._ensure_client() + _, kwargs = mock_from_url.call_args + assert kwargs.get("ssl") is True + + asyncio.run(_run()) + + +def test_async_redis_backend_config_mismatch_raises(): + """evalsha returning -2 raises BudgetConfigMismatchError (async path).""" + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + from shekel.exceptions import BudgetConfigMismatchError + + mock_client = AsyncMock() + mock_client.evalsha.return_value = [-2, b"spec_mismatch"] + + backend = AsyncRedisBackend() + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + with pytest.raises(BudgetConfigMismatchError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + asyncio.run(_run()) + + +def test_async_redis_backend_fail_open_when_configured(): + """AsyncRedisBackend with on_unavailable='open' allows calls through on error.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = AsyncRedisBackend(on_unavailable="open") + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + allowed, exceeded = await backend.check_and_add( + "api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0} + ) + assert allowed is True + assert exceeded is None + + asyncio.run(_run()) + + +def 
test_async_redis_backend_circuit_breaker_opens_and_stops_redis(): + """After N errors the async circuit opens and Redis is no longer called.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import AsyncRedisBackend + from shekel.exceptions import BudgetExceededError + + mock_client = AsyncMock() + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = AsyncRedisBackend( + on_unavailable="closed", + circuit_breaker_threshold=3, + circuit_breaker_cooldown=10.0, + ) + backend._client = mock_client + backend._script_sha = "sha" + + async def _run() -> None: + for _ in range(3): + with pytest.raises(BudgetExceededError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + calls_before = mock_client.evalsha.call_count + + with pytest.raises(BudgetExceededError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + assert mock_client.evalsha.call_count == calls_before + + asyncio.run(_run()) + + +def test_async_redis_backend_circuit_breaker_resets_after_cooldown(): + """Async circuit breaker allows retry after cooldown.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import AsyncRedisBackend + from shekel.exceptions import BudgetExceededError + + mock_client = AsyncMock() + mock_client.evalsha.side_effect = redis_lib.RedisError("timeout") + + backend = AsyncRedisBackend( + on_unavailable="closed", + circuit_breaker_threshold=3, + circuit_breaker_cooldown=10.0, + ) + backend._client = mock_client + backend._script_sha = "sha" + + t0 = 1000.0 + + async def _run() -> None: + with patch("time.monotonic", return_value=t0): + for _ in range(3): + with pytest.raises(BudgetExceededError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + calls_before = mock_client.evalsha.call_count + + with patch("time.monotonic", return_value=t0 + 11.0): + with 
pytest.raises(BudgetExceededError): + await backend.check_and_add("api", {"usd": 0.01}, {"usd": 5.0}, {"usd": 3600.0}) + + assert mock_client.evalsha.call_count > calls_before + + asyncio.run(_run()) + + +def test_async_redis_backend_get_state_returns_empty_on_error(): + """async get_state returns {} when hgetall raises.""" + pytest.importorskip("redis") + import redis as redis_lib + + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.hgetall.side_effect = redis_lib.RedisError("timeout") + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + assert await backend.get_state("api") == {} + + asyncio.run(_run()) + + +def test_async_redis_backend_get_state_skips_non_numeric_value(): + """async get_state silently skips fields that can't be parsed as float.""" + pytest.importorskip("redis") + from shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.hgetall.return_value = {b"usd:spent": b"not_a_number"} + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + assert await backend.get_state("api") == {} + + asyncio.run(_run()) + + +def test_redis_backend_ensure_script_loads_on_first_call(): + """_ensure_script loads the Lua script when _script_sha is None.""" + pytest.importorskip("redis") + from shekel.backends.redis import RedisBackend + + mock_client = MagicMock() + mock_client.script_load.return_value = "loaded_sha" + backend = RedisBackend() + backend._client = mock_client + # _script_sha starts as None — force the load path + assert backend._script_sha is None + sha = backend._ensure_script() + assert sha == "loaded_sha" + assert backend._script_sha == "loaded_sha" + mock_client.script_load.assert_called_once() + + +def test_async_redis_backend_ensure_script_loads_on_first_call(): + """AsyncRedisBackend._ensure_script loads the Lua script when _script_sha is None.""" + pytest.importorskip("redis") + from 
shekel.backends.redis import AsyncRedisBackend + + mock_client = AsyncMock() + mock_client.script_load.return_value = "loaded_sha" + backend = AsyncRedisBackend() + backend._client = mock_client + + async def _run() -> None: + assert backend._script_sha is None + sha = await backend._ensure_script() + assert sha == "loaded_sha" + assert backend._script_sha == "loaded_sha" + mock_client.script_load.assert_called_once() + + asyncio.run(_run()) + + +def test_lazy_window_reset_skips_when_window_not_yet_expired(): + """_lazy_window_reset returns early if the primary window has not yet expired.""" + from unittest.mock import patch + + from shekel._temporal import InMemoryBackend, TemporalBudget + + backend = InMemoryBackend() + tb = TemporalBudget(max_usd=5.0, window_seconds=3600, name="no_expire", backend=backend) + t0 = 1000.0 + + with patch("time.monotonic", return_value=t0): + backend.check_and_add("no_expire", {"usd": 1.0}, {"usd": 5.0}, {"usd": 3600.0}) + + # Only 100 s elapsed — window has NOT expired (3600 s window) + with patch("time.monotonic", return_value=t0 + 100.0): + tb._lazy_window_reset() # must return early without emitting diff --git a/tests/test_langchain_wrappers.py b/tests/test_langchain_wrappers.py new file mode 100644 index 0000000..20197cf --- /dev/null +++ b/tests/test_langchain_wrappers.py @@ -0,0 +1,712 @@ +"""Tests for LangChain chain-level budget enforcement. + +Domain: LangChainRunnerAdapter — patching, chain gate, spend attribution, async support. 
+""" + +from __future__ import annotations + +import sys +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import shekel.providers.langchain as lc_mod +from shekel import budget +from shekel._budget import Budget +from shekel._runtime import ShekelRuntime +from shekel.exceptions import BudgetExceededError, ChainBudgetExceededError +from shekel.providers.langchain import LangChainRunnerAdapter + +try: + from langchain_core.runnables.base import Runnable, RunnableLambda, RunnableSequence + + LANGCHAIN_AVAILABLE = True +except ImportError: + LANGCHAIN_AVAILABLE = False + +pytestmark = pytest.mark.skipif(not LANGCHAIN_AVAILABLE, reason="langchain_core not installed") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def restore_adapter_state(): + """Restore LangChainRunnerAdapter patch state and ShekelRuntime registry after each test.""" + original_refcount = lc_mod._chain_patch_refcount + original_cwc = lc_mod._original_call_with_config + original_acwc = lc_mod._original_acall_with_config + original_seq_invoke = lc_mod._original_sequence_invoke + original_seq_ainvoke = lc_mod._original_sequence_ainvoke + original_registry = ShekelRuntime._adapter_registry[:] + + real_cwc = Runnable._call_with_config + real_acwc = Runnable._acall_with_config + real_seq_invoke = RunnableSequence.invoke + real_seq_ainvoke = RunnableSequence.ainvoke + + yield + + lc_mod._chain_patch_refcount = original_refcount + lc_mod._original_call_with_config = original_cwc + lc_mod._original_acall_with_config = original_acwc + lc_mod._original_sequence_invoke = original_seq_invoke + lc_mod._original_sequence_ainvoke = original_seq_ainvoke + + Runnable._call_with_config = real_cwc # type: ignore[method-assign] + Runnable._acall_with_config = real_acwc # type: ignore[method-assign] + RunnableSequence.invoke 
= real_seq_invoke # type: ignore[method-assign] + RunnableSequence.ainvoke = real_seq_ainvoke # type: ignore[method-assign] + + ShekelRuntime._adapter_registry = original_registry + + +# --------------------------------------------------------------------------- +# Group 1: LangChainRunnerAdapter registered in ShekelRuntime +# --------------------------------------------------------------------------- + + +def test_langchain_runner_adapter_in_runtime_registry() -> None: + """LangChainRunnerAdapter is registered in ShekelRuntime at import time.""" + assert LangChainRunnerAdapter in ShekelRuntime._adapter_registry + + +def test_langchain_runner_adapter_registered_exactly_once() -> None: + """LangChainRunnerAdapter appears exactly once in the registry.""" + count = sum(1 for a in ShekelRuntime._adapter_registry if a is LangChainRunnerAdapter) + assert count == 1 + + +# --------------------------------------------------------------------------- +# Group 2: install_patches / remove_patches lifecycle +# --------------------------------------------------------------------------- + + +def test_install_patches_replaces_call_with_config() -> None: + """install_patches() patches Runnable._call_with_config.""" + original = Runnable._call_with_config + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + assert Runnable._call_with_config is not original + adapter.remove_patches(b) + + +def test_install_patches_replaces_sequence_invoke() -> None: + """install_patches() patches RunnableSequence.invoke.""" + original = RunnableSequence.invoke + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + assert RunnableSequence.invoke is not original + adapter.remove_patches(b) + + +def test_install_patches_raises_import_error_when_langchain_absent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """install_patches() raises ImportError when langchain_core is not importable.""" + 
monkeypatch.setitem(sys.modules, "langchain_core", None) + monkeypatch.setitem(sys.modules, "langchain_core.runnables", None) + monkeypatch.setitem(sys.modules, "langchain_core.runnables.base", None) + adapter = LangChainRunnerAdapter() + with pytest.raises(ImportError): + adapter.install_patches(Budget(max_usd=5.00)) + + +def test_remove_patches_restores_call_with_config() -> None: + """remove_patches() restores Runnable._call_with_config.""" + original = Runnable._call_with_config + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + adapter.remove_patches(b) + assert Runnable._call_with_config is original + + +def test_remove_patches_restores_sequence_invoke() -> None: + """remove_patches() restores RunnableSequence.invoke.""" + original = RunnableSequence.invoke + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + adapter.remove_patches(b) + assert RunnableSequence.invoke is original + + +def test_reference_counting_patch_applied_once_for_nested_budgets() -> None: + """Nested budgets increment refcount but only patch once.""" + b1 = Budget(max_usd=5.00) + b2 = Budget(max_usd=3.00) + a1 = LangChainRunnerAdapter() + a2 = LangChainRunnerAdapter() + + a1.install_patches(b1) + assert lc_mod._chain_patch_refcount == 1 + + a2.install_patches(b2) + assert lc_mod._chain_patch_refcount == 2 + + a2.remove_patches(b2) + assert lc_mod._chain_patch_refcount == 1 + + a1.remove_patches(b1) + assert lc_mod._chain_patch_refcount == 0 + + +def test_remove_patches_is_safe_at_zero_refcount() -> None: + """remove_patches() is a no-op when refcount is already 0.""" + adapter = LangChainRunnerAdapter() + lc_mod._chain_patch_refcount = 0 + adapter.remove_patches(Budget(max_usd=5.00)) + assert lc_mod._chain_patch_refcount == 0 + + +def test_patches_restored_on_budget_exit_even_on_exception() -> None: + """Runnable methods are restored even when an exception propagates.""" + original_cwc = 
Runnable._call_with_config + raised = False + try: + with budget(max_usd=5.00): + assert Runnable._call_with_config is not original_cwc + raise ValueError("simulated error") + except ValueError: + raised = True + assert raised + assert Runnable._call_with_config is original_cwc + + +# --------------------------------------------------------------------------- +# Group 3: Budget.chain() API +# --------------------------------------------------------------------------- + + +def test_budget_chain_registers_component_budget() -> None: + """b.chain() registers a ComponentBudget in _chain_budgets.""" + b = Budget(max_usd=5.00) + b.chain("summarize", max_usd=0.50) + assert "summarize" in b._chain_budgets + assert b._chain_budgets["summarize"].max_usd == 0.50 + + +def test_chain_method_returns_self_for_chaining() -> None: + """b.chain() returns self for method chaining.""" + b = Budget(max_usd=5.00) + result = b.chain("a", max_usd=0.10).chain("b", max_usd=0.20) + assert result is b + assert "a" in b._chain_budgets + assert "b" in b._chain_budgets + + +def test_chain_max_usd_must_be_positive() -> None: + """b.chain() raises ValueError when max_usd <= 0.""" + b = Budget(max_usd=5.00) + with pytest.raises(ValueError): + b.chain("step", max_usd=0.0) + with pytest.raises(ValueError): + b.chain("step", max_usd=-1.0) + + +def test_chain_budgets_accessible_outside_context_manager() -> None: + """b.chain() can be called before __enter__ — registration is separate from activation.""" + b = Budget(max_usd=5.00) + b.chain("step", max_usd=0.30) + assert b._chain_budgets["step"]._spent == 0.0 + + +# --------------------------------------------------------------------------- +# Group 4: Pre-execution gate — explicit chain cap (RunnableLambda) +# --------------------------------------------------------------------------- + + +def test_explicit_chain_cap_exceeded_raises_before_lambda_runs() -> None: + """ChainBudgetExceededError raised BEFORE the RunnableLambda body executes.""" + 
executed: list[bool] = [] + + def my_fn(x: Any) -> Any: + executed.append(True) + return x + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 # exhaust the cap + chain = RunnableLambda(my_fn, name="my_fn") + + with pytest.raises(ChainBudgetExceededError): + chain.invoke("input") + + assert executed == [] + + +def test_explicit_chain_cap_error_carries_correct_fields() -> None: + """ChainBudgetExceededError has correct chain_name, spent, limit.""" + + def my_fn(x: Any) -> Any: + return x + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 + chain = RunnableLambda(my_fn, name="my_fn") + + with pytest.raises(ChainBudgetExceededError) as exc_info: + chain.invoke("input") + + err = exc_info.value + assert err.chain_name == "my_fn" + assert err.spent == pytest.approx(0.10) + assert err.limit == pytest.approx(0.10) + + +def test_chain_cap_not_exceeded_allows_lambda_to_run() -> None: + """Lambda runs normally when spend is below chain cap.""" + + def my_fn(x: Any) -> Any: + return "result" + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.50) + b._chain_budgets["my_fn"]._spent = 0.05 # below cap + chain = RunnableLambda(my_fn, name="my_fn") + result = chain.invoke("input") + + assert result == "result" + + +def test_unnamed_lambda_not_gated() -> None: + """Unnamed RunnableLambda is not gated even with chain caps registered.""" + executed: list[bool] = [] + + def my_fn(x: Any) -> Any: + executed.append(True) + return x + + with budget(max_usd=5.00) as b: + b.chain("other", max_usd=0.10) + b._chain_budgets["other"]._spent = 0.10 # other cap exhausted + # No name — should NOT be gated + chain = RunnableLambda(my_fn) + result = chain.invoke("input") + + assert executed == [True] + assert result == "input" + + +def test_chain_budget_exceeded_is_subclass_of_budget_exceeded_error() -> None: + """ChainBudgetExceededError is catchable as 
BudgetExceededError.""" + + def my_fn(x: Any) -> Any: + return x + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 + chain = RunnableLambda(my_fn, name="my_fn") + + with pytest.raises(BudgetExceededError): + chain.invoke("input") + + +# --------------------------------------------------------------------------- +# Group 5: Pre-execution gate — parent budget exhaustion +# --------------------------------------------------------------------------- + + +def test_parent_budget_exhausted_raises_chain_budget_exceeded_error() -> None: + """ChainBudgetExceededError raised when parent budget is at limit.""" + executed: list[bool] = [] + + def my_fn(x: Any) -> Any: + executed.append(True) + return x + + with budget(max_usd=1.00) as b: + b.chain("my_fn", max_usd=5.00) # chain cap not exceeded + b._spent = 1.00 # parent exhausted + chain = RunnableLambda(my_fn, name="my_fn") + + with pytest.raises(ChainBudgetExceededError) as exc_info: + chain.invoke("input") + + assert executed == [] + assert exc_info.value.chain_name == "my_fn" + + +def test_no_active_budget_lambda_runs_unguarded() -> None: + """Lambda executes normally when invoked outside any budget context.""" + + def my_fn(x: Any) -> Any: + return "unguarded" + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 + chain = RunnableLambda(my_fn, name="my_fn") + + # Invoke OUTSIDE budget — get_active_budget() returns None + result = chain.invoke("input") + assert result == "unguarded" + + +# --------------------------------------------------------------------------- +# Group 6: Post-execution spend attribution (RunnableLambda) +# --------------------------------------------------------------------------- + + +def test_spend_delta_attributed_to_chain_component_budget() -> None: + """Spend during lambda execution is attributed to its chain ComponentBudget._spent.""" + with budget(max_usd=5.00) as b: + 
b.chain("my_fn", max_usd=1.00) + + def my_fn(x: Any) -> Any: + b._spent += 0.15 # simulate LLM call + return x + + chain = RunnableLambda(my_fn, name="my_fn") + chain.invoke("input") + cb = b._chain_budgets["my_fn"] + + assert cb._spent == pytest.approx(0.15) + + +def test_zero_spend_in_lambda_leaves_chain_budget_at_zero() -> None: + """Lambda with no LLM spend leaves chain ComponentBudget._spent at 0.""" + + def my_fn(x: Any) -> Any: + return x + + with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.50) + chain = RunnableLambda(my_fn, name="my_fn") + chain.invoke("input") + cb = b._chain_budgets["my_fn"] + + assert cb._spent == pytest.approx(0.0) + + +def test_spend_not_attributed_when_no_cap_registered_for_name() -> None: + """Named lambda without a registered cap doesn't populate _chain_budgets.""" + + def my_fn(x: Any) -> Any: + b._spent += 0.05 + return x + + with budget(max_usd=5.00) as b: + chain = RunnableLambda(my_fn, name="my_fn") + chain.invoke("input") + + assert "my_fn" not in b._chain_budgets + + +# --------------------------------------------------------------------------- +# Group 7: RunnableSequence (LCEL pipeline) cap enforcement +# --------------------------------------------------------------------------- + + +def test_sequence_cap_exceeded_raises_before_pipeline_runs() -> None: + """ChainBudgetExceededError raised before a capped RunnableSequence executes.""" + executed: list[bool] = [] + + def step1(x: Any) -> Any: + executed.append(True) + return x + + def step2(x: Any) -> Any: + executed.append(True) + return x + + with budget(max_usd=5.00) as b: + b.chain("my_pipeline", max_usd=0.10) + b._chain_budgets["my_pipeline"]._spent = 0.10 + + seq = RunnableLambda(step1) | RunnableLambda(step2) + seq.name = "my_pipeline" # type: ignore[assignment] + + with pytest.raises(ChainBudgetExceededError): + seq.invoke("input") + + assert executed == [] + + +def test_sequence_cap_not_exceeded_allows_pipeline_to_run() -> None: + """RunnableSequence 
runs normally when below its chain cap.""" + + def step1(x: Any) -> Any: + return x + "_s1" + + def step2(x: Any) -> Any: + return x + "_s2" + + with budget(max_usd=5.00) as b: + b.chain("my_pipeline", max_usd=0.50) + seq = RunnableLambda(step1) | RunnableLambda(step2) + seq.name = "my_pipeline" # type: ignore[assignment] + result = seq.invoke("input") + + assert result == "input_s1_s2" + + +def test_sequence_spend_attributed_to_chain_budget() -> None: + """Spend during RunnableSequence execution attributed to its chain ComponentBudget.""" + with budget(max_usd=5.00) as b: + b.chain("my_pipeline", max_usd=1.00) + + def step1(x: Any) -> Any: + b._spent += 0.20 + return x + + seq = RunnableLambda(step1) + seq.name = "my_pipeline" # type: ignore[assignment] + seq.invoke("input") + + assert b._chain_budgets["my_pipeline"]._spent == pytest.approx(0.20) + + +# --------------------------------------------------------------------------- +# Group 8: Nested budget cap inheritance (parent-chain lookup) +# --------------------------------------------------------------------------- + + +def test_chain_cap_on_outer_budget_enforced_in_nested_inner_budget() -> None: + """Cap registered on outer budget raises ChainBudgetExceededError inside inner context.""" + executed: list[bool] = [] + + def my_fn(x: Any) -> Any: + executed.append(True) + return x + + with budget(max_usd=5.00, name="outer") as outer: + outer.chain("my_fn", max_usd=0.10) + outer._chain_budgets["my_fn"]._spent = 0.10 + + with budget(max_usd=2.00, name="inner"): + chain = RunnableLambda(my_fn, name="my_fn") + with pytest.raises(ChainBudgetExceededError): + chain.invoke("input") + + assert executed == [] + + +def test_inner_budget_chain_cap_takes_precedence_over_outer() -> None: + """Cap registered on inner budget is used even if outer also has a cap for same name.""" + + def my_fn(x: Any) -> Any: + return x + + with budget(max_usd=5.00, name="outer") as outer: + outer.chain("my_fn", max_usd=1.00) # outer cap: $1.00 
— not exceeded + + with budget(max_usd=2.00, name="inner") as inner: + inner.chain("my_fn", max_usd=0.05) + inner._chain_budgets["my_fn"]._spent = 0.05 # exhaust inner cap + + chain = RunnableLambda(my_fn, name="my_fn") + with pytest.raises(ChainBudgetExceededError) as exc_info: + chain.invoke("input") + + assert exc_info.value.limit == pytest.approx(0.05) + + +# --------------------------------------------------------------------------- +# Group 9: Async support +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_lambda_cap_exceeded_raises_before_execution() -> None: + """Async ChainBudgetExceededError raised before awaiting lambda body.""" + executed: list[bool] = [] + + async def my_fn(x: Any) -> Any: + executed.append(True) + return x + + async with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 + chain = RunnableLambda(my_fn, name="my_fn") + + with pytest.raises(ChainBudgetExceededError): + await chain.ainvoke("input") + + assert executed == [] + + +@pytest.mark.asyncio +async def test_async_lambda_runs_and_returns_result() -> None: + """Async lambda runs normally when below cap.""" + + async def my_fn(x: Any) -> Any: + return "async_result" + + async with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=1.00) + chain = RunnableLambda(my_fn, name="my_fn") + result = await chain.ainvoke("input") + + assert result == "async_result" + + +@pytest.mark.asyncio +async def test_async_spend_attributed_to_chain_budget() -> None: + """Spend during async lambda execution attributed to chain ComponentBudget.""" + + async with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=1.00) + + async def my_fn(x: Any) -> Any: + b._spent += 0.25 + return x + + chain = RunnableLambda(my_fn, name="my_fn") + await chain.ainvoke("input") + + assert b._chain_budgets["my_fn"]._spent == pytest.approx(0.25) + + +@pytest.mark.asyncio +async def 
test_async_no_active_budget_runs_unguarded() -> None: + """Async lambda executes normally outside any budget context.""" + + async def my_fn(x: Any) -> Any: + return "unguarded_async" + + async with budget(max_usd=5.00) as b: + b.chain("my_fn", max_usd=0.10) + b._chain_budgets["my_fn"]._spent = 0.10 + chain = RunnableLambda(my_fn, name="my_fn") + + result = await chain.ainvoke("input") + assert result == "unguarded_async" + + +# --------------------------------------------------------------------------- +# Group 10: Integration with Budget lifecycle via ShekelRuntime +# --------------------------------------------------------------------------- + + +def test_call_with_config_patched_on_budget_enter() -> None: + """Runnable._call_with_config is patched when a budget context is entered.""" + original = Runnable._call_with_config + with budget(max_usd=5.00): + assert Runnable._call_with_config is not original + assert Runnable._call_with_config is original + + +def test_call_with_config_restored_on_budget_exit() -> None: + """Runnable._call_with_config is restored after the budget context exits.""" + original = Runnable._call_with_config + with budget(max_usd=5.00): + pass + assert Runnable._call_with_config is original + + +def test_named_lambda_passthrough_when_patch_active_but_no_budget() -> None: + """Named lambda runs unguarded when patch is installed but no budget context is active.""" + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + try: + chain = RunnableLambda(lambda x: "passthrough", name="my_fn") + result = chain.invoke("input") # no active budget — passthrough path + assert result == "passthrough" + finally: + adapter.remove_patches(b) + + +@pytest.mark.asyncio +async def test_named_lambda_async_passthrough_when_no_budget() -> None: + """Async named lambda runs unguarded when patch is installed but no budget context is active.""" + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + 
adapter.install_patches(b) + try: + + async def my_fn(x: Any) -> Any: + return "async_passthrough" + + chain = RunnableLambda(my_fn, name="my_fn") + result = await chain.ainvoke("input") + assert result == "async_passthrough" + finally: + adapter.remove_patches(b) + + +def test_sequence_passthrough_when_patch_active_but_no_budget() -> None: + """Named RunnableSequence runs unguarded when patch is installed but no budget context.""" + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + try: + # Use | to create an actual RunnableSequence + seq = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x + "_done") + seq.name = "my_seq" # type: ignore[assignment] + result = seq.invoke("input") # no active budget — passthrough path + assert result == "input_done" + finally: + adapter.remove_patches(b) + + +@pytest.mark.asyncio +async def test_async_sequence_cap_enforced_and_spend_attributed() -> None: + """Async RunnableSequence (created with |) cap enforced and spend attributed.""" + async with budget(max_usd=5.00) as b: + b.chain("async_pipeline", max_usd=1.00) + + seq = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x + "_done") + seq.name = "async_pipeline" # type: ignore[assignment] + result = await seq.ainvoke("input") + + assert result == "input_done" + + +@pytest.mark.asyncio +async def test_sequence_async_passthrough_when_patch_active_but_no_budget() -> None: + """Named RunnableSequence async runs unguarded when patch is installed but no budget context.""" + adapter = LangChainRunnerAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + try: + # Use | to create an actual RunnableSequence + seq = RunnableLambda(lambda x: x) | RunnableLambda(lambda x: x + "_async_done") + seq.name = "my_seq" # type: ignore[assignment] + result = await seq.ainvoke("input") + assert result == "input_async_done" + finally: + adapter.remove_patches(b) + + +def test_mock_llm_spend_tracked_with_chain_cap() -> None: + 
"""End-to-end: mocked LLM call inside lambda → spend tracked in chain ComponentBudget.""" + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "ok" + mock_resp.usage.prompt_tokens = 100 + mock_resp.usage.completion_tokens = 50 + mock_resp.model = "gpt-4o-mini" + + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + import openai + + client = openai.OpenAI(api_key="test") + + with budget(max_usd=5.00) as b: + b.chain("llm_step", max_usd=1.00) + + def llm_step(x: Any) -> Any: + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "hello"}], + ) + return x + + chain = RunnableLambda(llm_step, name="llm_step") + chain.invoke("input") + + assert b._chain_budgets["llm_step"]._spent > 0 + assert b._chain_budgets["llm_step"]._spent == pytest.approx(b.spent) diff --git a/tests/test_langgraph_wrappers.py b/tests/test_langgraph_wrappers.py new file mode 100644 index 0000000..9968dbb --- /dev/null +++ b/tests/test_langgraph_wrappers.py @@ -0,0 +1,794 @@ +"""Tests for LangGraph node-level budget enforcement (v0.3.1). + +Domain: lg_mod.LangGraphAdapter — patching, node gate, spend attribution, async support. 
+""" + +from __future__ import annotations + +import inspect +import sys +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import shekel.providers.langgraph as lg_mod +from shekel import budget +from shekel._budget import Budget +from shekel._runtime import ShekelRuntime +from shekel.exceptions import BudgetExceededError, NodeBudgetExceededError + +try: + from langgraph.graph import END, StateGraph + from typing_extensions import TypedDict + + LANGGRAPH_AVAILABLE = True + + class _State(TypedDict): + value: int + +except ImportError: + LANGGRAPH_AVAILABLE = False + +pytestmark = pytest.mark.skipif(not LANGGRAPH_AVAILABLE, reason="langgraph not installed") + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def restore_adapter_state(): + """Restore lg_mod.LangGraphAdapter patch state and ShekelRuntime registry after each test.""" + original_refcount = lg_mod._patch_refcount + original_add_node = lg_mod._original_add_node + original_registry = ShekelRuntime._adapter_registry[:] + + # Capture the actual StateGraph.add_node before any test mutations + real_add_node = StateGraph.add_node + + yield + + # Restore module state + lg_mod._patch_refcount = original_refcount + lg_mod._original_add_node = original_add_node + + # Restore StateGraph + StateGraph.add_node = real_add_node # type: ignore[method-assign] + + # Restore registry + ShekelRuntime._adapter_registry = original_registry + + +def _make_simple_graph(node_name: str, node_fn: Any) -> Any: + """Build a minimal single-node StateGraph.""" + g = StateGraph(_State) + g.add_node(node_name, node_fn) + g.set_entry_point(node_name) + g.add_edge(node_name, END) + return g.compile() + + +# --------------------------------------------------------------------------- +# Group 1: lg_mod.LangGraphAdapter registered in ShekelRuntime +# 
--------------------------------------------------------------------------- + + +def test_langgraph_adapter_in_runtime_registry() -> None: + """lg_mod.LangGraphAdapter is registered in ShekelRuntime at import time.""" + assert lg_mod.LangGraphAdapter in ShekelRuntime._adapter_registry + + +def test_langgraph_adapter_is_registered_exactly_once() -> None: + """lg_mod.LangGraphAdapter appears exactly once in the registry.""" + count = sum(1 for a in ShekelRuntime._adapter_registry if a is lg_mod.LangGraphAdapter) + assert count == 1 + + +# --------------------------------------------------------------------------- +# Group 2: install_patches / remove_patches lifecycle +# --------------------------------------------------------------------------- + + +def test_install_patches_replaces_add_node() -> None: + """install_patches() replaces StateGraph.add_node with the gated version.""" + original = StateGraph.add_node + adapter = lg_mod.LangGraphAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + assert StateGraph.add_node is not original + adapter.remove_patches(b) + + +def test_install_patches_raises_import_error_when_langgraph_absent( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """install_patches() raises ImportError when langgraph is not importable.""" + monkeypatch.setitem(sys.modules, "langgraph", None) + monkeypatch.setitem(sys.modules, "langgraph.graph", None) + monkeypatch.setitem(sys.modules, "langgraph.graph.state", None) + adapter = lg_mod.LangGraphAdapter() + with pytest.raises(ImportError): + adapter.install_patches(Budget(max_usd=5.00)) + + +def test_remove_patches_restores_add_node() -> None: + """remove_patches() restores StateGraph.add_node to the original.""" + original = StateGraph.add_node + adapter = lg_mod.LangGraphAdapter() + b = Budget(max_usd=5.00) + adapter.install_patches(b) + adapter.remove_patches(b) + assert StateGraph.add_node is original + + +def test_reference_counting_patch_applied_once_for_nested_budgets() -> None: + 
"""Nested budgets increment refcount but only patch add_node once.""" + original = StateGraph.add_node + b1 = Budget(max_usd=5.00) + b2 = Budget(max_usd=3.00) + a1 = lg_mod.LangGraphAdapter() + a2 = lg_mod.LangGraphAdapter() + + a1.install_patches(b1) + patched = StateGraph.add_node + assert patched is not original + assert lg_mod._patch_refcount == 1 + + a2.install_patches(b2) + assert StateGraph.add_node is patched # still the same patch + assert lg_mod._patch_refcount == 2 + + a2.remove_patches(b2) + assert StateGraph.add_node is patched # still patched (refcount 1) + assert lg_mod._patch_refcount == 1 + + a1.remove_patches(b1) + assert StateGraph.add_node is original # fully restored + assert lg_mod._patch_refcount == 0 + + +def test_remove_patches_is_safe_at_zero_refcount() -> None: + """remove_patches() is a no-op when refcount is already 0.""" + adapter = lg_mod.LangGraphAdapter() + lg_mod._patch_refcount = 0 + adapter.remove_patches(Budget(max_usd=5.00)) # must not raise + assert lg_mod._patch_refcount == 0 + + +def test_remove_patches_is_safe_when_original_is_none() -> None: + """remove_patches() is a no-op when _original_add_node is None (refcount 1→0).""" + adapter = lg_mod.LangGraphAdapter() + lg_mod._patch_refcount = 1 + lg_mod._original_add_node = None # simulate missing original + adapter.remove_patches(Budget(max_usd=5.00)) # must not raise + assert lg_mod._patch_refcount == 0 + + +def test_add_node_callable_action_gets_wrapped() -> None: + """Compiled subgraph (callable) passed as action is wrapped with the gate.""" + with budget(max_usd=5.00): + # Build a subgraph and pass it as the action + sub = StateGraph(_State) + sub.add_node("inner", lambda s: {"value": s["value"] + 1}) + sub.set_entry_point("inner") + sub.add_edge("inner", END) + compiled_sub = sub.compile() + + # compiled_sub is callable — should be wrapped, not passed through + wrapped = lg_mod._make_gate(compiled_sub, "sub") + assert callable(wrapped) + assert wrapped.__wrapped__ is 
compiled_sub or hasattr(wrapped, "__wrapped__") + + +def test_add_node_restored_on_budget_exit_even_on_exception() -> None: + """StateGraph.add_node is restored even when an exception propagates.""" + original = StateGraph.add_node + try: + with budget(max_usd=5.00): + patched = StateGraph.add_node + assert patched is not original + raise ValueError("simulated error") + except ValueError: + pass + assert StateGraph.add_node is original + + +# --------------------------------------------------------------------------- +# Group 3: Node wrapping — name resolution and callable detection +# --------------------------------------------------------------------------- + + +def test_add_node_name_fn_form_wraps_action() -> None: + """add_node('name', fn) wraps fn; node body is reachable via the wrapper.""" + called: list[bool] = [] + + def my_node(state: _State) -> dict: + called.append(True) + return {"value": state["value"] + 1} + + with budget(max_usd=5.00): + app = _make_simple_graph("my_node", my_node) + result = app.invoke({"value": 0}) + + assert result["value"] == 1 + assert called == [True] + + +def test_add_node_fn_form_wraps_callable() -> None: + """add_node(fn) form (no explicit name) wraps fn using fn.__name__.""" + called: list[bool] = [] + + def my_node(state: _State) -> dict: + called.append(True) + return {"value": 99} + + with budget(max_usd=5.00): + g = StateGraph(_State) + g.add_node(my_node) + g.set_entry_point("my_node") + g.add_edge("my_node", END) + app = g.compile() + app.invoke({"value": 0}) + + assert called == [True] + + +def test_functools_wraps_preserves_original_name() -> None: + """The gated wrapper has __name__ == original function __name__.""" + + def target(state: _State) -> dict: + return {"value": 0} + + wrapped = lg_mod._make_gate(target, "target") + assert wrapped.__name__ == "target" + + +def test_sync_node_wrapped_as_sync() -> None: + """A sync node function is wrapped with a sync (non-coroutine) gate.""" + + def sync_node(state: 
_State) -> dict: + return {"value": 0} + + wrapped = lg_mod._make_gate(sync_node, "sync_node") + assert not inspect.iscoroutinefunction(wrapped) + + +def test_async_node_wrapped_as_async() -> None: + """An async node function is wrapped with an async (coroutine) gate.""" + + async def async_node(state: _State) -> dict: + return {"value": 0} + + wrapped = lg_mod._make_gate(async_node, "async_node") + assert inspect.iscoroutinefunction(wrapped) + + +# --------------------------------------------------------------------------- +# Group 4: Pre-execution gate — explicit node cap +# --------------------------------------------------------------------------- + + +def test_explicit_node_cap_exceeded_raises_before_node_runs() -> None: + """NodeBudgetExceededError raised BEFORE the node body executes.""" + executed: list[bool] = [] + + def fetch(state: _State) -> dict: + executed.append(True) # must never be reached + return {"value": state["value"] + 1} + + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=0.10) + b._node_budgets["fetch"]._spent = 0.10 # exhaust the cap + app = _make_simple_graph("fetch", fetch) + + with pytest.raises(NodeBudgetExceededError): + app.invoke({"value": 0}) + + assert executed == [] # node body never ran + + +def test_explicit_node_cap_error_carries_correct_fields() -> None: + """NodeBudgetExceededError has correct node_name, spent, limit.""" + + def fetch(state: _State) -> dict: + return {"value": 0} + + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=0.10) + b._node_budgets["fetch"]._spent = 0.10 + app = _make_simple_graph("fetch", fetch) + + with pytest.raises(NodeBudgetExceededError) as exc_info: + app.invoke({"value": 0}) + + err = exc_info.value + assert err.node_name == "fetch" + assert err.spent == pytest.approx(0.10) + assert err.limit == pytest.approx(0.10) + + +def test_explicit_node_cap_not_exceeded_allows_node_to_run() -> None: + """Node runs normally when spend is below cap.""" + + def fetch(state: _State) -> 
dict: + return {"value": state["value"] + 10} + + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=0.50) + b._node_budgets["fetch"]._spent = 0.05 # below cap + app = _make_simple_graph("fetch", fetch) + result = app.invoke({"value": 0}) + + assert result["value"] == 10 + + +def test_node_cap_does_not_affect_other_nodes() -> None: + """Capping one node does not prevent other nodes from running.""" + ran: list[str] = [] + + def n1(state: _State) -> dict: + ran.append("n1") + return {"value": state["value"] + 1} + + def n2(state: _State) -> dict: + ran.append("n2") + return {"value": state["value"] + 1} + + with budget(max_usd=5.00) as b: + b.node("n1", max_usd=0.10) + b._node_budgets["n1"]._spent = 0.09 # below cap + # n2 has no cap + + g = StateGraph(_State) + g.add_node("n1", n1) + g.add_node("n2", n2) + g.set_entry_point("n1") + g.add_edge("n1", "n2") + g.add_edge("n2", END) + app = g.compile() + app.invoke({"value": 0}) + + assert ran == ["n1", "n2"] + + +# --------------------------------------------------------------------------- +# Group 5: Pre-execution gate — parent budget exhaustion +# --------------------------------------------------------------------------- + + +def test_parent_budget_exhausted_raises_node_budget_exceeded_error() -> None: + """NodeBudgetExceededError raised when parent budget is at limit.""" + executed: list[bool] = [] + + def process(state: _State) -> dict: + executed.append(True) + return {"value": 0} + + with budget(max_usd=1.00) as b: + b._spent = 1.00 # simulate exhausted parent + app = _make_simple_graph("process", process) + + with pytest.raises(NodeBudgetExceededError) as exc_info: + app.invoke({"value": 0}) + + assert executed == [] + assert exc_info.value.node_name == "process" + + +def test_parent_budget_not_exhausted_allows_node() -> None: + """Node runs when parent budget still has headroom.""" + + def process(state: _State) -> dict: + return {"value": 42} + + with budget(max_usd=1.00) as b: + b._spent = 0.50 # half 
used + app = _make_simple_graph("process", process) + result = app.invoke({"value": 0}) + + assert result["value"] == 42 + + +def test_track_only_budget_no_gate() -> None: + """Track-only budget (no max_usd) never triggers the parent gate.""" + + def process(state: _State) -> dict: + return {"value": 7} + + with budget() as b: + b._spent = 999.0 # high spend, no limit + app = _make_simple_graph("process", process) + result = app.invoke({"value": 0}) + + assert result["value"] == 7 + + +def test_node_budget_exceeded_is_subclass_of_budget_exceeded_error() -> None: + """NodeBudgetExceededError is catchable as BudgetExceededError.""" + + def process(state: _State) -> dict: + return {"value": 0} + + with budget(max_usd=1.00) as b: + b._spent = 1.00 + app = _make_simple_graph("process", process) + + with pytest.raises(BudgetExceededError): # catches NodeBudgetExceededError + app.invoke({"value": 0}) + + +def test_no_active_budget_node_runs_unguarded() -> None: + """Node executes normally when invoked outside any budget context.""" + + def process(state: _State) -> dict: + return {"value": 100} + + # Build inside a budget context (so add_node is patched) + with budget(max_usd=5.00): + app = _make_simple_graph("process", process) + + # Invoke OUTSIDE any budget — get_active_budget() returns None + result = app.invoke({"value": 0}) + assert result["value"] == 100 + + +# --------------------------------------------------------------------------- +# Group 6: Post-execution spend attribution +# --------------------------------------------------------------------------- + + +def test_spend_delta_attributed_to_component_budget() -> None: + """Spend during node execution is attributed to its ComponentBudget._spent.""" + + def fetch(state: _State) -> dict: + return {"value": state["value"]} + + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=1.00) + + # Simulate spend occurring during the node by patching _spent after capture + original_fn_ref: list[Any] = [] + 
original_fetch = fetch + + def fetch_with_spend(state: _State) -> dict: + # Simulate an LLM call adding $0.15 to the budget + b._spent += 0.15 + return original_fn_ref[0](state) + + original_fn_ref.append(original_fetch) + + app = _make_simple_graph("fetch", fetch_with_spend) + app.invoke({"value": 0}) + cb = b._node_budgets["fetch"] + + assert cb._spent == pytest.approx(0.15) + + +def test_zero_spend_in_node_leaves_component_budget_at_zero() -> None: + """When a node causes no LLM spend, ComponentBudget._spent stays 0.0.""" + + def cheap_node(state: _State) -> dict: + return {"value": 1} + + with budget(max_usd=5.00) as b: + b.node("cheap_node", max_usd=0.50) + app = _make_simple_graph("cheap_node", cheap_node) + app.invoke({"value": 0}) + cb = b._node_budgets["cheap_node"] + + assert cb._spent == pytest.approx(0.0) + + +def test_spend_not_attributed_when_no_cap_registered() -> None: + """Nodes without an explicit cap don't populate _node_budgets._spent.""" + + def process(state: _State) -> dict: + b._spent += 0.05 # simulated LLM call + return {"value": 0} + + with budget(max_usd=5.00) as b: + app = _make_simple_graph("process", process) + app.invoke({"value": 0}) + + assert "process" not in b._node_budgets + + +def test_multiple_nodes_spend_attributed_separately() -> None: + """Each node's spend is attributed to its own ComponentBudget independently.""" + + with budget(max_usd=5.00) as b: + b.node("n1", max_usd=1.00) + b.node("n2", max_usd=1.00) + + def n1(state: _State) -> dict: + b._spent += 0.10 + return {"value": state["value"] + 1} + + def n2(state: _State) -> dict: + b._spent += 0.30 + return {"value": state["value"] + 1} + + g = StateGraph(_State) + g.add_node("n1", n1) + g.add_node("n2", n2) + g.set_entry_point("n1") + g.add_edge("n1", "n2") + g.add_edge("n2", END) + app = g.compile() + app.invoke({"value": 0}) + + assert b._node_budgets["n1"]._spent == pytest.approx(0.10) + assert b._node_budgets["n2"]._spent == pytest.approx(0.30) + + +# 
--------------------------------------------------------------------------- +# Group 7: Async nodes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_async_no_active_budget_node_runs_unguarded() -> None: + """Async node runs normally when invoked outside any budget context.""" + + async def async_process(state: _State) -> dict: + return {"value": 77} + + with budget(max_usd=5.00): + g = StateGraph(_State) + g.add_node("async_process", async_process) + g.set_entry_point("async_process") + g.add_edge("async_process", END) + app = g.compile() + + # Invoke outside any budget — get_active_budget() returns None + result = await app.ainvoke({"value": 0}) + assert result["value"] == 77 + + +@pytest.mark.asyncio +async def test_async_node_runs_and_returns_result() -> None: + """Async node wrapped by gate still executes and returns its result.""" + + async def async_fetch(state: _State) -> dict: + return {"value": state["value"] + 5} + + async with budget(max_usd=5.00): + g = StateGraph(_State) + g.add_node("async_fetch", async_fetch) + g.set_entry_point("async_fetch") + g.add_edge("async_fetch", END) + app = g.compile() + result = await app.ainvoke({"value": 0}) + + assert result["value"] == 5 + + +@pytest.mark.asyncio +async def test_async_node_cap_exceeded_raises_before_execution() -> None: + """Async node gate raises NodeBudgetExceededError before awaiting node body.""" + executed: list[bool] = [] + + async def async_fetch(state: _State) -> dict: + executed.append(True) + return {"value": 0} + + async with budget(max_usd=5.00) as b: + b.node("async_fetch", max_usd=0.10) + b._node_budgets["async_fetch"]._spent = 0.10 + + g = StateGraph(_State) + g.add_node("async_fetch", async_fetch) + g.set_entry_point("async_fetch") + g.add_edge("async_fetch", END) + app = g.compile() + + with pytest.raises(NodeBudgetExceededError): + await app.ainvoke({"value": 0}) + + assert executed == [] + + +@pytest.mark.asyncio 
+async def test_async_spend_attributed_to_component_budget() -> None: + """Spend during async node execution is attributed to its ComponentBudget.""" + + async with budget(max_usd=5.00) as b: + b.node("async_fetch", max_usd=1.00) + + async def async_fetch(state: _State) -> dict: + b._spent += 0.20 + return {"value": 0} + + g = StateGraph(_State) + g.add_node("async_fetch", async_fetch) + g.set_entry_point("async_fetch") + g.add_edge("async_fetch", END) + app = g.compile() + await app.ainvoke({"value": 0}) + + assert b._node_budgets["async_fetch"]._spent == pytest.approx(0.20) + + +# --------------------------------------------------------------------------- +# Group 8: Integration with Budget lifecycle via ShekelRuntime +# --------------------------------------------------------------------------- + + +def test_add_node_patched_on_budget_enter() -> None: + """StateGraph.add_node is patched when a budget context is entered.""" + original = StateGraph.add_node + with budget(max_usd=5.00): + assert StateGraph.add_node is not original + assert StateGraph.add_node is original + + +def test_add_node_restored_on_budget_exit() -> None: + """StateGraph.add_node is restored after the budget context exits.""" + original = StateGraph.add_node + with budget(max_usd=5.00): + pass + assert StateGraph.add_node is original + + +def test_mock_llm_spend_tracked_with_node_cap() -> None: + """End-to-end: mocked LLM call inside node → spend tracked in ComponentBudget.""" + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "ok" + mock_resp.usage.prompt_tokens = 100 + mock_resp.usage.completion_tokens = 50 + mock_resp.model = "gpt-4o-mini" + + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + import openai + + client = openai.OpenAI(api_key="test") + + with budget(max_usd=5.00) as b: + b.node("llm_node", max_usd=1.00) + + def llm_node(state: _State) -> dict: + client.chat.completions.create( + model="gpt-4o-mini", + 
messages=[{"role": "user", "content": "hello"}], + ) + return {"value": state["value"] + 1} + + app = _make_simple_graph("llm_node", llm_node) + app.invoke({"value": 0}) + + assert b._node_budgets["llm_node"]._spent > 0 + assert b._node_budgets["llm_node"]._spent == pytest.approx(b.spent) + + +# --------------------------------------------------------------------------- +# Group 9: Nested budget cap inheritance (parent-chain lookup) +# --------------------------------------------------------------------------- + + +def test_node_cap_on_outer_budget_enforced_in_nested_inner_budget() -> None: + """Cap registered on outer budget raises NodeBudgetExceededError inside inner context.""" + executed: list[bool] = [] + + def fetch(state: _State) -> dict: + executed.append(True) + return {"value": 0} + + with budget(max_usd=5.00, name="outer") as outer: + outer.node("fetch", max_usd=0.10) + outer._node_budgets["fetch"]._spent = 0.10 # exhaust the cap + + with budget(max_usd=2.00, name="inner"): + app = _make_simple_graph("fetch", fetch) + with pytest.raises(NodeBudgetExceededError): + app.invoke({"value": 0}) + + assert executed == [] + + +def test_node_cap_spend_attributed_to_outer_budget_when_invoked_in_inner_context() -> None: + """Spend delta attributed to outer budget's ComponentBudget when cap is registered there.""" + with budget(max_usd=5.00, name="outer") as outer: + outer.node("fetch", max_usd=1.00) + + with budget(max_usd=2.00, name="inner") as inner: + + def fetch(state: _State) -> dict: + inner._spent += 0.15 # simulate LLM spend on inner budget + return {"value": 0} + + app = _make_simple_graph("fetch", fetch) + app.invoke({"value": 0}) + + assert outer._node_budgets["fetch"]._spent == pytest.approx(0.15) + + +def test_node_cap_found_on_grandparent_budget() -> None: + """Cap registered on grandparent is found when graph runs two levels deep.""" + executed: list[bool] = [] + + def fetch(state: _State) -> dict: + executed.append(True) + return {"value": 0} + + 
with budget(max_usd=10.00, name="root") as root: + root.node("fetch", max_usd=0.10) + root._node_budgets["fetch"]._spent = 0.10 + + with budget(max_usd=5.00, name="mid"): + with budget(max_usd=2.00, name="inner"): + app = _make_simple_graph("fetch", fetch) + with pytest.raises(NodeBudgetExceededError): + app.invoke({"value": 0}) + + assert executed == [] + + +def test_inner_budget_node_cap_takes_precedence_over_outer() -> None: + """Cap registered on inner budget is used even if outer also has a cap for the same node.""" + + def fetch(state: _State) -> dict: + return {"value": 0} + + with budget(max_usd=5.00, name="outer") as outer: + outer.node("fetch", max_usd=1.00) # outer cap: $1.00 — not exceeded + + with budget(max_usd=2.00, name="inner") as inner: + inner.node("fetch", max_usd=0.05) + inner._node_budgets["fetch"]._spent = 0.05 # exhaust inner cap + + app = _make_simple_graph("fetch", fetch) + with pytest.raises(NodeBudgetExceededError) as exc_info: + app.invoke({"value": 0}) + + # Raised using the inner cap ($0.05), not the outer cap ($1.00) + assert exc_info.value.limit == pytest.approx(0.05) + + +def test_looping_node_circuit_breaks_on_parent_budget() -> None: + """A node in a retry loop is stopped when parent budget is exhausted.""" + mock_resp = MagicMock() + mock_resp.choices[0].message.content = "ok" + mock_resp.usage.prompt_tokens = 10000 # high token count to exhaust budget fast + mock_resp.usage.completion_tokens = 5000 + mock_resp.model = "gpt-4o-mini" + + from typing_extensions import TypedDict + + class _LoopState(TypedDict): + count: int + + with patch( + "openai.resources.chat.completions.Completions.create", + return_value=mock_resp, + ): + import openai + + client = openai.OpenAI(api_key="test") + + def loop_node(state: _LoopState) -> dict: + client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "x"}], + ) + return {"count": state["count"] + 1} + + def should_loop(state: Any) -> str: + return "done" if 
state["count"] >= 100 else "loop" + + g = StateGraph(_LoopState) + g.add_node("loop_node", loop_node) + g.set_entry_point("loop_node") + g.add_conditional_edges("loop_node", should_loop, {"loop": "loop_node", "done": END}) + app = g.compile() + + with pytest.raises((BudgetExceededError, NodeBudgetExceededError)): + with budget( + max_usd=0.001, + price_per_1k_tokens={"input": 1.0, "output": 1.0}, + ): + app.invoke({"count": 0}) diff --git a/tests/test_runtime.py b/tests/test_runtime.py new file mode 100644 index 0000000..d816c1e --- /dev/null +++ b/tests/test_runtime.py @@ -0,0 +1,482 @@ +"""Tests for ShekelRuntime and component budget API (v0.3.1). + +Domain: runtime framework detection scaffold and per-component budget registration. +""" + +from __future__ import annotations + +import pytest + +from shekel import budget +from shekel._budget import Budget, ComponentBudget +from shekel._runtime import ShekelRuntime +from shekel.exceptions import ( + AgentBudgetExceededError, + BudgetExceededError, + NodeBudgetExceededError, + SessionBudgetExceededError, + TaskBudgetExceededError, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def clean_runtime_registry(): + """Save, clear, and restore ShekelRuntime._adapter_registry between tests. + + Clearing at the start keeps these unit tests isolated from framework + adapters (e.g. LangGraphAdapter) registered at import time. Tests that + need a specific adapter register it themselves. 
+ """ + original = ShekelRuntime._adapter_registry[:] + ShekelRuntime._adapter_registry = [] + yield + ShekelRuntime._adapter_registry = original + + +# --------------------------------------------------------------------------- +# Exception hierarchy +# --------------------------------------------------------------------------- + + +def test_node_budget_exceeded_error_is_budget_exceeded_error() -> None: + err = NodeBudgetExceededError(node_name="fetch", spent=0.05, limit=0.01) + assert isinstance(err, BudgetExceededError) + + +def test_agent_budget_exceeded_error_is_budget_exceeded_error() -> None: + err = AgentBudgetExceededError(agent_name="researcher", spent=1.50, limit=1.00) + assert isinstance(err, BudgetExceededError) + + +def test_task_budget_exceeded_error_is_budget_exceeded_error() -> None: + err = TaskBudgetExceededError(task_name="write_report", spent=0.60, limit=0.50) + assert isinstance(err, BudgetExceededError) + + +def test_session_budget_exceeded_error_is_budget_exceeded_error() -> None: + err = SessionBudgetExceededError(agent_name="assistant", spent=6.00, limit=5.00) + assert isinstance(err, BudgetExceededError) + + +def test_node_error_carries_node_name_spent_limit() -> None: + err = NodeBudgetExceededError(node_name="summarize", spent=0.20, limit=0.10) + assert err.node_name == "summarize" + assert err.spent == pytest.approx(0.20) + assert err.limit == pytest.approx(0.10) + + +def test_agent_error_carries_agent_name_spent_limit() -> None: + err = AgentBudgetExceededError(agent_name="writer", spent=2.00, limit=1.50) + assert err.agent_name == "writer" + assert err.spent == pytest.approx(2.00) + assert err.limit == pytest.approx(1.50) + + +def test_task_error_carries_task_name_spent_limit() -> None: + err = TaskBudgetExceededError(task_name="research", spent=0.80, limit=0.50) + assert err.task_name == "research" + assert err.spent == pytest.approx(0.80) + assert err.limit == pytest.approx(0.50) + + +def 
test_session_error_carries_agent_name_and_window() -> None: + err = SessionBudgetExceededError(agent_name="bot", spent=6.00, limit=5.00, window=86400.0) + assert err.agent_name == "bot" + assert err.spent == pytest.approx(6.00) + assert err.limit == pytest.approx(5.00) + assert err.window == pytest.approx(86400.0) + + +def test_session_error_window_defaults_to_none() -> None: + err = SessionBudgetExceededError(agent_name="bot", spent=6.00, limit=5.00) + assert err.window is None + + +def test_existing_except_budget_exceeded_error_catches_node_error() -> None: + """Existing catch-all still works after adding subclasses.""" + with pytest.raises(BudgetExceededError): + raise NodeBudgetExceededError(node_name="x", spent=1.0, limit=0.5) + + +def test_existing_except_budget_exceeded_error_catches_agent_error() -> None: + with pytest.raises(BudgetExceededError): + raise AgentBudgetExceededError(agent_name="x", spent=1.0, limit=0.5) + + +def test_existing_except_budget_exceeded_error_catches_task_error() -> None: + with pytest.raises(BudgetExceededError): + raise TaskBudgetExceededError(task_name="x", spent=1.0, limit=0.5) + + +def test_existing_except_budget_exceeded_error_catches_session_error() -> None: + with pytest.raises(BudgetExceededError): + raise SessionBudgetExceededError(agent_name="x", spent=1.0, limit=0.5) + + +# --------------------------------------------------------------------------- +# ShekelRuntime — registry and probe/release +# --------------------------------------------------------------------------- + + +def test_runtime_probe_is_noop_with_empty_registry() -> None: + """probe() with no registered adapters does nothing and does not raise.""" + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() # must not raise + assert runtime._active_adapters == [] + + +def test_runtime_register_adds_to_registry() -> None: + class FakeAdapter: + def install_patches(self, budget: Budget) -> None: + pass + + def remove_patches(self, budget: 
Budget) -> None: + pass + + ShekelRuntime.register(FakeAdapter) + assert FakeAdapter in ShekelRuntime._adapter_registry + + +def test_runtime_probe_activates_registered_adapter() -> None: + installed: list[Budget] = [] + + class FakeAdapter: + def install_patches(self, b: Budget) -> None: + installed.append(b) + + def remove_patches(self, b: Budget) -> None: + pass + + ShekelRuntime.register(FakeAdapter) + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() + + assert len(installed) == 1 + assert installed[0] is b + assert len(runtime._active_adapters) == 1 + + +def test_runtime_probe_skips_adapter_that_raises_import_error() -> None: + class BadAdapter: + def install_patches(self, b: Budget) -> None: + raise ImportError("framework not installed") + + def remove_patches(self, b: Budget) -> None: + pass + + ShekelRuntime.register(BadAdapter) + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() # must not raise + + assert runtime._active_adapters == [] + + +def test_runtime_release_calls_remove_patches_on_active_adapters() -> None: + removed: list[Budget] = [] + + class FakeAdapter: + def install_patches(self, b: Budget) -> None: + pass + + def remove_patches(self, b: Budget) -> None: + removed.append(b) + + ShekelRuntime.register(FakeAdapter) + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() + runtime.release() + + assert len(removed) == 1 + assert removed[0] is b + assert runtime._active_adapters == [] + + +def test_runtime_release_clears_active_adapters() -> None: + class FakeAdapter: + def install_patches(self, b: Budget) -> None: + pass + + def remove_patches(self, b: Budget) -> None: + pass + + ShekelRuntime.register(FakeAdapter) + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() + assert len(runtime._active_adapters) == 1 + runtime.release() + assert runtime._active_adapters == [] + + +def test_runtime_release_tolerates_remove_patches_exception() -> None: + """release() does 
not propagate exceptions from remove_patches.""" + + class BrokenAdapter: + def install_patches(self, b: Budget) -> None: + pass + + def remove_patches(self, b: Budget) -> None: + raise RuntimeError("unexpected error during cleanup") + + ShekelRuntime.register(BrokenAdapter) + b = Budget(max_usd=5.00) + runtime = ShekelRuntime(b) + runtime.probe() + runtime.release() # must not raise + + +# --------------------------------------------------------------------------- +# ShekelRuntime — integration with Budget lifecycle +# --------------------------------------------------------------------------- + + +def test_runtime_probe_called_on_budget_enter() -> None: + probed: list[Budget] = [] + + class TrackingAdapter: + def install_patches(self, b: Budget) -> None: + probed.append(b) + + def remove_patches(self, b: Budget) -> None: + pass + + ShekelRuntime.register(TrackingAdapter) + + with budget(max_usd=5.00) as b: + assert len(probed) == 1 + assert probed[0] is b + + +def test_runtime_release_called_on_budget_exit() -> None: + released: list[Budget] = [] + + class TrackingAdapter: + def install_patches(self, b: Budget) -> None: + pass + + def remove_patches(self, b: Budget) -> None: + released.append(b) + + ShekelRuntime.register(TrackingAdapter) + + with budget(max_usd=5.00) as b: + pass + + assert len(released) == 1 + assert released[0] is b + + +def test_runtime_release_called_on_budget_exit_even_on_exception() -> None: + released: list[bool] = [] + + class TrackingAdapter: + def install_patches(self, b: Budget) -> None: + pass + + def remove_patches(self, b: Budget) -> None: + released.append(True) + + ShekelRuntime.register(TrackingAdapter) + + try: + with budget(max_usd=5.00): + raise ValueError("simulated error") + except ValueError: + pass + + assert released == [True] + + +async def _async_budget_helper(probed: list[Budget]) -> None: + class TrackingAdapter: + def install_patches(self, b: Budget) -> None: + probed.append(b) + + def remove_patches(self, b: 
Budget) -> None: + pass + + ShekelRuntime.register(TrackingAdapter) + + async with budget(max_usd=5.00) as b: + assert len(probed) == 1 + assert probed[0] is b + + +@pytest.mark.asyncio +async def test_runtime_probe_called_on_async_budget_enter() -> None: + probed: list[Budget] = [] + await _async_budget_helper(probed) + + +# --------------------------------------------------------------------------- +# Budget.node() / .agent() / .task() API +# --------------------------------------------------------------------------- + + +def test_budget_node_registers_component_budget() -> None: + b = Budget(max_usd=5.00) + b.node("fetch_data", max_usd=0.50) + assert "fetch_data" in b._node_budgets + assert b._node_budgets["fetch_data"].max_usd == pytest.approx(0.50) + + +def test_budget_agent_registers_component_budget() -> None: + b = Budget(max_usd=5.00) + b.agent("researcher", max_usd=1.50) + assert "researcher" in b._agent_budgets + assert b._agent_budgets["researcher"].max_usd == pytest.approx(1.50) + + +def test_budget_task_registers_component_budget() -> None: + b = Budget(max_usd=5.00) + b.task("write_report", max_usd=0.50) + assert "write_report" in b._task_budgets + assert b._task_budgets["write_report"].max_usd == pytest.approx(0.50) + + +def test_component_methods_return_self_for_chaining() -> None: + b = Budget(max_usd=5.00) + result = b.node("a", max_usd=0.50).agent("b", max_usd=1.00).task("c", max_usd=0.30) + assert result is b + + +def test_node_max_usd_must_be_positive() -> None: + b = Budget(max_usd=5.00) + with pytest.raises(ValueError, match="positive"): + b.node("fetch", max_usd=0.0) + + +def test_node_max_usd_must_not_be_negative() -> None: + b = Budget(max_usd=5.00) + with pytest.raises(ValueError, match="positive"): + b.node("fetch", max_usd=-1.0) + + +def test_agent_max_usd_must_be_positive() -> None: + b = Budget(max_usd=5.00) + with pytest.raises(ValueError, match="positive"): + b.agent("researcher", max_usd=0.0) + + +def 
test_task_max_usd_must_be_positive() -> None: + b = Budget(max_usd=5.00) + with pytest.raises(ValueError, match="positive"): + b.task("write", max_usd=0.0) + + +def test_component_budgets_accessible_via_internal_dicts() -> None: + b = Budget(max_usd=5.00) + b.node("n1", max_usd=0.50) + b.agent("a1", max_usd=1.00) + b.task("t1", max_usd=0.30) + + assert isinstance(b._node_budgets["n1"], ComponentBudget) + assert isinstance(b._agent_budgets["a1"], ComponentBudget) + assert isinstance(b._task_budgets["t1"], ComponentBudget) + + +def test_component_budgets_can_be_registered_before_enter() -> None: + """Registering component caps before opening the context is valid.""" + b = Budget(max_usd=5.00) + b.node("fetch", max_usd=0.50) + with b: + assert "fetch" in b._node_budgets + + +def test_component_budgets_can_be_registered_inside_context() -> None: + """Registering component caps inside the context is also valid.""" + with budget(max_usd=5.00) as b: + b.node("fetch", max_usd=0.50) + assert "fetch" in b._node_budgets + + +def test_component_budget_initial_spent_is_zero() -> None: + b = Budget(max_usd=5.00) + b.node("fetch", max_usd=0.50) + cb = b._node_budgets["fetch"] + assert cb._spent == pytest.approx(0.0) + + +def test_overwriting_node_budget_replaces_previous() -> None: + b = Budget(max_usd=5.00) + b.node("fetch", max_usd=0.50) + b.node("fetch", max_usd=1.00) + assert b._node_budgets["fetch"].max_usd == pytest.approx(1.00) + + +# --------------------------------------------------------------------------- +# ComponentBudget dataclass +# --------------------------------------------------------------------------- + + +def test_component_budget_has_name_and_max_usd() -> None: + cb = ComponentBudget(name="my_node", max_usd=0.50) + assert cb.name == "my_node" + assert cb.max_usd == pytest.approx(0.50) + + +def test_component_budget_spent_starts_at_zero() -> None: + cb = ComponentBudget(name="my_node", max_usd=0.50) + assert cb._spent == pytest.approx(0.0) + + +# 
--------------------------------------------------------------------------- +# tree() output includes component budgets +# --------------------------------------------------------------------------- + + +def test_tree_includes_registered_node_budgets() -> None: + with budget(max_usd=5.00) as b: + b.node("fetch_data", max_usd=0.50) + output = b.tree() + assert "fetch_data" in output + assert "node" in output + + +def test_tree_includes_registered_agent_budgets() -> None: + with budget(max_usd=5.00) as b: + b.agent("researcher", max_usd=1.50) + output = b.tree() + assert "researcher" in output + assert "agent" in output + + +def test_tree_includes_registered_task_budgets() -> None: + with budget(max_usd=5.00) as b: + b.task("write_report", max_usd=0.50) + output = b.tree() + assert "write_report" in output + assert "task" in output + + +def test_tree_shows_zero_spend_for_fresh_component_budgets() -> None: + with budget(max_usd=5.00) as b: + b.node("fetch_data", max_usd=0.50) + output = b.tree() + assert "$0.0000" in output + + +def test_tree_shows_multiple_component_types() -> None: + with budget(max_usd=5.00) as b: + b.node("n1", max_usd=0.50) + b.agent("a1", max_usd=1.00) + b.task("t1", max_usd=0.30) + output = b.tree() + assert "n1" in output + assert "a1" in output + assert "t1" in output + + +def test_tree_without_component_budgets_unchanged() -> None: + """tree() with no component budgets must still work (no regression).""" + with budget(max_usd=5.00) as b: + output = b.tree() + assert "unnamed" in output or output # just must not raise diff --git a/tests/test_temporal_budgets.py b/tests/test_temporal_budgets.py index f3847ca..2bbf183 100644 --- a/tests/test_temporal_budgets.py +++ b/tests/test_temporal_budgets.py @@ -155,7 +155,7 @@ def test_budget_exceeded_error_retry_after_defaults_none(): # --------------------------------------------------------------------------- -# Group D — InMemoryBackend +# Group D — InMemoryBackend (new multi-cap protocol) # 
--------------------------------------------------------------------------- @@ -163,48 +163,58 @@ def test_in_memory_backend_get_state_fresh(): from shekel._temporal import InMemoryBackend backend = InMemoryBackend() - assert backend.get_state("new_key") == (0.0, None) + assert backend.get_state("new_key") == {} def test_in_memory_backend_check_and_add_within_limit(): from shekel._temporal import InMemoryBackend backend = InMemoryBackend() - result = backend.check_and_add("budget1", 2.0, 5.0, 3600.0) - assert result is True - spent, window_start = backend.get_state("budget1") - assert spent == 2.0 - assert window_start is not None + allowed, exceeded = backend.check_and_add( + "budget1", + amounts={"usd": 2.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is True + assert exceeded is None + state = backend.get_state("budget1") + assert state["usd"] == pytest.approx(2.0) def test_in_memory_backend_check_and_add_exceeds_limit(): from shekel._temporal import InMemoryBackend backend = InMemoryBackend() - result = backend.check_and_add("budget1", 6.0, 5.0, 3600.0) - assert result is False - # State should remain unchanged (0.0, None) since we never accepted it - spent, window_start = backend.get_state("budget1") - assert spent == 0.0 - assert window_start is None + allowed, exceeded = backend.check_and_add( + "budget1", + amounts={"usd": 6.0}, + limits={"usd": 5.0}, + windows={"usd": 3600.0}, + ) + assert allowed is False + assert exceeded == "usd" + # State should remain empty since nothing was accepted + assert backend.get_state("budget1") == {} def test_in_memory_backend_reset_clears_state(): from shekel._temporal import InMemoryBackend backend = InMemoryBackend() - backend.check_and_add("budget1", 2.0, 5.0, 3600.0) + backend.check_and_add("budget1", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) backend.reset("budget1") - assert backend.get_state("budget1") == (0.0, None) + assert backend.get_state("budget1") == {} def 
test_in_memory_backend_window_start_set_on_first_add(): from shekel._temporal import InMemoryBackend backend = InMemoryBackend() - backend.check_and_add("budget1", 1.0, 5.0, 3600.0) - _, window_start = backend.get_state("budget1") - assert window_start is not None + backend.check_and_add("budget1", {"usd": 1.0}, {"usd": 5.0}, {"usd": 3600.0}) + state = backend.get_state("budget1") + assert "usd" in state + assert state["usd"] == pytest.approx(1.0) def test_in_memory_backend_window_expires_resets(): @@ -214,14 +224,16 @@ def test_in_memory_backend_window_expires_resets(): t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("budget1", 4.0, 5.0, 3600.0) + backend.check_and_add("budget1", {"usd": 4.0}, {"usd": 5.0}, {"usd": 3600.0}) # Advance time past window with patch("time.monotonic", return_value=t0 + 3601.0): - result = backend.check_and_add("budget1", 4.0, 5.0, 3600.0) - assert result is True # fresh window, 4.0 < 5.0 so should succeed - spent, _ = backend.get_state("budget1") - assert spent == 4.0 # fresh window + allowed, exceeded = backend.check_and_add( + "budget1", {"usd": 4.0}, {"usd": 5.0}, {"usd": 3600.0} + ) + assert allowed is True # fresh window, 4.0 < 5.0 so should succeed + state = backend.get_state("budget1") + assert state["usd"] == pytest.approx(4.0) # fresh window # --------------------------------------------------------------------------- @@ -255,12 +267,12 @@ def test_temporal_budget_window_resets_after_expiry(): # Fill the window with patch("time.monotonic", return_value=t0): - backend.check_and_add("test", 4.5, 5.0, 3600.0) + backend.check_and_add("test", {"usd": 4.5}, {"usd": 5.0}, {"usd": 3600.0}) # After window expires, should be able to spend again with patch("time.monotonic", return_value=t0 + 3601.0): - result = backend.check_and_add("test", 3.0, 5.0, 3600.0) - assert result is True + allowed, _ = backend.check_and_add("test", {"usd": 3.0}, {"usd": 5.0}, {"usd": 3600.0}) + assert allowed is True def 
test_temporal_budget_retry_after_in_error(): @@ -273,7 +285,7 @@ def test_temporal_budget_retry_after_in_error(): with patch("time.monotonic", return_value=t0): # Set up state: window started, nearly full - backend.check_and_add("test_retry", 4.5, 5.0, 3600.0) + backend.check_and_add("test_retry", {"usd": 4.5}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 100.0): with pytest.raises(BudgetExceededError) as exc_info: @@ -291,26 +303,26 @@ def test_temporal_budget_window_spent_in_error(): t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("test_ws", 4.5, 5.0, 3600.0) + backend.check_and_add("test_ws", {"usd": 4.5}, {"usd": 5.0}, {"usd": 3600.0}) with patch("time.monotonic", return_value=t0 + 100.0): with pytest.raises(BudgetExceededError) as exc_info: tb._record_spend(1.0, "test-model", {"input": 100, "output": 100}) - assert exc_info.value.window_spent == 4.5 + assert exc_info.value.window_spent == pytest.approx(4.5) def test_record_spend_window_expired_mid_context(): from shekel._temporal import InMemoryBackend, TemporalBudget from shekel.exceptions import BudgetExceededError - # Window expires between __enter__ and _record_spend (lines 158-159) + # Window expires between __enter__ and _record_spend backend = InMemoryBackend() tb = TemporalBudget(max_usd=5.0, window_seconds=3600, name="mid_expiry", backend=backend) t0 = 1000.0 # Seed backend with state (window started at t0) with patch("time.monotonic", return_value=t0): - backend.check_and_add("mid_expiry", 3.0, 5.0, 3600.0) + backend.check_and_add("mid_expiry", {"usd": 3.0}, {"usd": 5.0}, {"usd": 3600.0}) # At t0+3601 window has expired; attempt to add more than limit → raises with patch("time.monotonic", return_value=t0 + 3601.0): @@ -342,7 +354,7 @@ def test_lazy_window_reset_emit_guard(): t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("guard_test", 2.0, 5.0, 3600.0) + backend.check_and_add("guard_test", {"usd": 
2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch( "shekel.integrations.registry.AdapterRegistry.emit_event", @@ -477,7 +489,7 @@ def test_window_reset_event_emitted(): t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("test_emit", 2.0, 5.0, 3600.0) + backend.check_and_add("test_emit", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch.object(AdapterRegistry, "emit_event") as mock_emit: with patch("time.monotonic", return_value=t0 + 3601.0): @@ -498,7 +510,7 @@ def test_window_reset_event_payload(): t0 = 1000.0 with patch("time.monotonic", return_value=t0): - backend.check_and_add("test_payload", 2.0, 5.0, 3600.0) + backend.check_and_add("test_payload", {"usd": 2.0}, {"usd": 5.0}, {"usd": 3600.0}) with patch.object(AdapterRegistry, "emit_event") as mock_emit: with patch("time.monotonic", return_value=t0 + 3601.0): diff --git a/tmp-h-design/PRD-hierarchical-integration.md b/tmp-h-design/PRD-hierarchical-integration.md new file mode 100644 index 0000000..48c95de --- /dev/null +++ b/tmp-h-design/PRD-hierarchical-integration.md @@ -0,0 +1,235 @@ +# PRD: Hierarchical Budget Enforcement for Shekel + +**Status:** Draft +**Date:** 2026-03-15 +**Author:** Elish +**Version:** 1.0 + +--- + +## 1. Problem Statement + +Shekel today is a circuit breaker at the LLM call and tool call level. It stops individual expensive calls. It cannot stop: + +- A LangGraph node that loops 5,000 times at $0.001/call — no single call trips the breaker, but the run costs $5 +- A rogue CrewAI agent monopolizing crew budget while other agents sit idle +- An OpenClaw always-on agent silently accumulating $3,600/month from heartbeat calls + +These are real, documented failures. Developers are losing hundreds to thousands of dollars to patterns that a call-level circuit breaker is architecturally incapable of catching. + +The market response — Langfuse, LangSmith, Helicone, Portkey — is **observability** (show you what happened). 
Shekel's position is **enforcement** (stop it before it happens). No tool in the market today provides in-process, per-component, multi-level budget enforcement without a proxy server. + +--- + +## 2. Goals + +1. Extend shekel's circuit breaker from Level 1 (LLM/tool calls) to all four abstraction levels present in the running process +2. Maintain 100% backward compatibility — existing `budget()` usage must not change +3. Zero configuration for the common case — auto-detect and auto-instrument all available frameworks +4. Progressive disclosure — explicit per-component caps available when needed, not required +5. Keep the install story: `pip install shekel`, one line of code + +## 3. Non-Goals + +- Observability / cost dashboards (that's Langfuse, LangSmith) +- Proxy-based enforcement (that's LiteLLM, Portkey) +- Recovery logic — shekel throws, frameworks recover +- Supporting frameworks not yet released or documented +- Modifying any upstream framework (no PRs to LangGraph, CrewAI, etc.) + +--- + +## 4. Users + +**Primary: Python developers building LLM-powered agents** +- Use LangGraph, CrewAI, OpenAI Agents SDK, or OpenClaw +- Already know `with budget(max_usd=5): ...` +- Pain: unexpected bills from agent loops, rogue agents, always-on accumulation + +**Secondary: Teams deploying agents in production** +- Want org-level policy ("no single run exceeds $10") +- Want per-feature cost attribution for internal billing +- Want to adopt gradually: showback before chargeback + +--- + +## 5. User Stories + +### Phase 1 Foundation — ShekelRuntime (v0.3.1) + +**US-01:** As a developer, when I open a `budget()` context, shekel automatically detects which frameworks I have installed and instruments them — I don't need to configure anything. + +**US-02:** As a developer, I can declare explicit per-node, per-agent, or per-task caps with `b.node()`, `b.agent()`, `b.task()` — using vocabulary I already know from the framework I'm using. 
+ +**US-03:** As a developer, `budget.tree()` shows me a spend breakdown by detected level (session → agent → node → call) so I can see which component is driving cost. + +### Phase 2 — LangGraph Layer (v0.3.2) + +**US-04:** As a LangGraph developer, when a node exceeds its cap, I get a `NodeBudgetExceededError` — not a generic `BudgetExceededError` — so I know exactly which node tripped and can handle it specifically. + +**US-05:** As a LangGraph developer, I can set per-node caps: +```python +with budget(max_usd=5.00) as b: + b.node("fetch_data", max_usd=0.50) + graph.invoke(...) +``` + +**US-06:** As a LangGraph developer, without any explicit node caps, the parent budget still enforces the total — and `budget.tree()` shows me per-node attribution so I can decide later which nodes to cap explicitly. + +**US-07:** As a LangGraph developer, a looping node that fires hundreds of times is circuit-broken before it exhausts the parent budget — no single call is expensive, but the loop is caught. + +### Phase 3 — CrewAI Layer (v0.3.3) + +**US-08:** As a CrewAI developer, individual agent caps prevent one agent from consuming the entire crew budget: +```python +with budget(max_usd=10.00) as b: + b.agent("researcher", max_usd=3.00) + b.agent("writer", max_usd=2.00) + crew.kickoff() +``` + +**US-09:** As a CrewAI developer, a task that exceeds its cap is circuit-broken *before* it starts (pre-task check on `TaskStartedEvent`) — zero wasted spend. + +**US-10:** As a CrewAI developer, I get `AgentBudgetExceededError` or `TaskBudgetExceededError` — not a generic error — so my error handling can make intelligent decisions (retry with cheaper model, skip task, abort crew). + +### Phase 4 — Loop Detection (v0.3.4) + +**US-11:** As a developer, if my agent's spend rate spikes to 3× its rolling baseline within a short window, shekel circuit-breaks before the absolute limit is hit — catching runaway loops early. 
+ +**US-12:** As a developer, I can configure the velocity threshold: +```python +with budget(max_usd=10.00, loop_detection_multiplier=3.0, loop_detection_window=60): + graph.invoke(...) +``` + +### Phase 5 — Tiered Thresholds (v0.3.5) + +**US-13:** As a developer, I can configure N enforcement tiers instead of just warn + hard stop: +```python +with budget(max_usd=5.00, tiers=[ + (0.50, "warn"), + (0.75, "fallback:gpt-4o-mini"), + (0.90, "disable_tools"), + (1.00, "stop"), +]): + graph.invoke(...) +``` + +### Phase 6 — OpenClaw Layer (v0.3.6) + +**US-14:** As an OpenClaw developer, I can enforce a rolling-window budget per agent session: +```python +with budget("$5/day") as b: + b.agent("my_assistant", max_usd=2.00) + openclaw_agent.run() +``` + +**US-15:** As an OpenClaw developer, when an agent session budget is exhausted it is suspended — not killed. Other agents on the same Gateway continue running. + +### Phase 7 — DX Layer (v0.3.7) + +**US-16:** As a developer adopting shekel, I can start in showback mode — full tracking and attribution, zero enforcement — then flip to chargeback when I'm confident: +```python +with budget(max_usd=5.00, mode="showback"): # never raises + graph.invoke(...) +``` + +**US-17:** As a team lead, I can tag spend for cost attribution: +```python +@tool(price=0.01, tags=["search", "feature-x"]) +def web_search(query: str) -> str: ... + +budget.summary(group_by="tags") # cost by feature +``` + +--- + +## 6. 
Functional Requirements + +### FR-01: ShekelRuntime (v0.3.1) +- Probe for installed frameworks once at `budget.__enter__()` +- Detection is silent — no logs, no warnings if a framework is not installed +- `Budget` gains `.node(name, max_usd)`, `.agent(name, max_usd)`, `.task(name, max_usd)` methods +- Child budgets created by these methods roll up to the parent +- `budget.tree()` renders the full hierarchy + +### FR-02: LangGraph adapter (v0.3.2) +- Patch `StateGraph.add_node()` when `langgraph` is detected +- Every registered node function is wrapped with a pre-execution budget gate +- Gate checks: explicit node cap (if set) → parent budget remaining → record attribution +- `NodeBudgetExceededError` raised with fields: `node_name`, `spent`, `limit` +- Async node functions wrapped with async budget gate +- Patch is reference-counted: applied on first budget open, restored when last budget closes + +### FR-03: CrewAI adapter (v0.3.3) +- Register `ShekelEventListener(BaseEventListener)` on `crewai_event_bus` at budget open +- Deregister at budget close +- Subscribe to: `TaskStartedEvent`, `AgentExecutionStartedEvent`, `LLMCallCompletedEvent`, `TaskFailedEvent` +- Pre-task check raises `TaskBudgetExceededError` before task body executes +- Pre-agent check raises `AgentBudgetExceededError` before agent executes +- `AgentBudgetExceededError` fields: `agent_name`, `spent`, `limit` +- `TaskBudgetExceededError` fields: `task_name`, `spent`, `limit` + +### FR-04: Loop detection (v0.3.4) +- Track spend rate as rolling average over configurable window (default: 60s) +- Circuit-break if instantaneous rate exceeds `loop_detection_multiplier × baseline` (default: 3×) +- Configurable via `budget(loop_detection_multiplier=3.0, loop_detection_window=60)` +- Default: disabled (opt-in) + +### FR-05: Tiered thresholds (v0.3.5) +- `tiers` parameter accepts list of `(fraction, action)` tuples +- Actions: `"warn"`, `"fallback:<model>"`, `"disable_tools"`, `"stop"` +- Applied at all active
levels (parent and children) +- Backward compatible: existing `warn_at` + `fallback` still work + +### FR-06: OpenClaw adapter (v0.3.6) +- Detect `openclaw` package at budget open +- Hook into `openclaw-sdk` agent lifecycle events +- Use `TemporalBudget` as the default budget type for OpenClaw contexts +- Session circuit-break suspends agent, does not kill Gateway +- `SessionBudgetExceededError` fields: `agent_name`, `spent`, `limit`, `window` + +### FR-07: DX layer (v0.3.7) +- `mode="showback"` — track everything, raise nothing +- `mode="chargeback"` — default, existing behavior +- `tags` parameter on `@tool` decorator +- `budget.summary(group_by="tags")` output + +--- + +## 7. Non-Functional Requirements + +- **100% backward compatibility** — all existing tests must pass unchanged +- **100% code coverage** on all new code (project standard) +- **Zero required configuration** for implicit mode +- **Optional dependencies** — `langgraph`, `crewai`, `openclaw-sdk` are all optional; absence is silent +- **Thread and async safe** — `contextvars.ContextVar` used for budget propagation (already the pattern) +- **Performance** — budget gate overhead < 1ms per node/agent/task check + +--- + +## 8. Exception Hierarchy + +Extending the existing pattern: + +```python +BudgetExceededError # existing base — unchanged +├── NodeBudgetExceededError # new: LangGraph node +├── AgentBudgetExceededError # new: agent (CrewAI, OpenClaw) +├── TaskBudgetExceededError # new: task (CrewAI) +└── SessionBudgetExceededError # new: session (OpenClaw) + +ToolBudgetExceededError # existing — unchanged +``` + +All new exceptions inherit `BudgetExceededError`. Existing `except BudgetExceededError` catches everything. + +--- + +## 9. 
Success Metrics + +- LangGraph users can set per-node caps with one line of code +- CrewAI users can isolate rogue agent spend without stopping the crew +- `budget.tree()` shows per-level spend attribution out of the box +- Zero breaking changes — all v0.2.x tests pass on v0.3.1+ +- Competitor gap: AgentBudget (30 stars) has no hierarchy or per-task enforcement — shekel ships both diff --git a/tmp-h-design/design-hierarchical-integration.md b/tmp-h-design/design-hierarchical-integration.md new file mode 100644 index 0000000..9b98e24 --- /dev/null +++ b/tmp-h-design/design-hierarchical-integration.md @@ -0,0 +1,243 @@ +# Design Decision: Hierarchical Budget Enforcement + +**Status:** Draft +**Date:** 2026-03-15 +**Author:** Elish + +--- + +## Context + +Shekel currently enforces budgets as a flat circuit breaker at the LLM call and tool call level. This catches expensive individual calls but is blind to failure modes that only emerge at higher abstraction levels: + +- A LangGraph node looping 5,000 times at $0.001/call — no single call trips the breaker +- A rogue CrewAI agent monopolizing crew budget while other agents wait +- An OpenClaw always-on agent accumulating cost from silent heartbeat calls over hours + +Each failure mode requires a circuit breaker at a different level. The call-level breaker alone cannot protect against them. + +--- + +## Decision: Layered Circuit Breaker Architecture + +Shekel enforces budgets at every abstraction level that can be detected in the running process. Each level catches a distinct class of failure. 
+ +### The Four Levels + +``` +Level 4 — Session/Orchestrator OpenClaw agents temporal accumulation +Level 3 — Agent/Task CrewAI, OpenAI Agents rogue agent monopolization +Level 2 — Node/Subgraph LangGraph loop/retry spirals +Level 1 — LLM/Tool call all frameworks single expensive call (today) +``` + +**Each level is an independent circuit breaker.** A trip at Level 2 (node) does not automatically kill the Level 3 (agent) or Level 4 (session) budget. The parent framework decides whether to absorb the exception, route to a fallback, or propagate upward. Shekel enforces; the framework recovers. + +**All detected levels are active simultaneously.** If LangGraph runs inside a CrewAI task inside an OpenClaw session, all three levels instrument and enforce independently, with child budgets rolling up to parents. + +--- + +## Detection Strategy + +Shekel probes for available frameworks **once, at budget open** (`__enter__` / `__aenter__`). This is the correct semantic moment: the developer has declared intent to enforce a budget, and frameworks in use are already imported. + +Detection order (top-down): + +```python +# ShekelRuntime.probe() — called on budget open +1. try: import openclaw → activate SessionAdapter +2. try: import crewai → activate CrewAIAdapter (BaseEventListener) +3. try: import langgraph → activate LangGraphAdapter (add_node patch) +4. try: import openai.agents → activate OpenAIAgentsAdapter (guardrail) +5. always → LLM + Tool adapters (today's behavior) +``` + +Frameworks not installed are silently skipped. No configuration required. + +--- + +## API Design + +### Principle: zero new syntax for the common case + +The existing `budget()` API is unchanged. Auto-instrumentation is silent. Developers who don't know about hierarchical integration get it for free. + +```python +# Nothing changes — shekel auto-instruments all detected levels +with budget(max_usd=5.00): + graph.invoke(...) +``` + +**Implicit mode behavior:** no per-component cap. 
Every detected level gets attribution and spend tracking. The parent budget is the only enforced limit. When the parent is exhausted, whichever component is running at that moment receives the exception. + +### Explicit overrides — only when you need per-component caps + +```python +with budget(max_usd=5.00) as b: + b.node("fetch_data", max_usd=0.50) # LangGraph node cap + b.node("summarize", max_usd=1.00) # LangGraph node cap + b.agent("researcher", max_usd=1.50) # CrewAI / OpenClaw agent cap + b.task("write_report", max_usd=0.50) # CrewAI task cap + graph.invoke(...) +``` + +**Method names mirror the framework's own vocabulary.** A LangGraph developer recognizes `.node()`. A CrewAI developer recognizes `.agent()` and `.task()`. No new vocabulary to learn. + +**Unregistered components** (nodes/agents/tasks without explicit caps) share the parent's remaining budget uncapped. They are always tracked and attributed. + +### Temporal budgets for always-on runtimes + +For OpenClaw and any persistent agent runtime, use the existing `TemporalBudget`: + +```python +# Rolling-window enforcement — resets every hour +with budget("$5/hr") as b: + b.agent("my_assistant", max_usd=2.00) + openclaw_agent.run() +``` + +Flat caps (`max_usd`) apply to single invocations. Temporal caps (`$N/hr`, `$N/day`) apply to always-on runtimes. Both can coexist in a hierarchy. + +--- + +## Auto-Instrumentation per Level + +Each level uses the framework's native extensibility hook. No framework internals are bypassed. + +### Level 1 — LLM/Tool calls (today) + +**Mechanism:** Monkey-patch provider SDK methods at budget open. Restore on budget close. +**Hook point:** `openai.resources.chat.completions.Completions.create`, `anthropic.messages.Messages.create`, etc. +**Status:** Implemented. + +### Level 2 — LangGraph nodes + +**Mechanism:** Patch `StateGraph.add_node()` at budget open; restore it on budget close. Every node function is wrapped with a budget gate before it is registered in the graph.
+ +```python +# What shekel does transparently: +original_add_node = StateGraph.add_node + +def patched_add_node(self, node_name, fn, **kwargs): + wrapped = _budget_gate(node_name, fn) + return original_add_node(self, node_name, wrapped, **kwargs) +``` + +**Budget gate behavior:** +- Check if active budget has an explicit cap for this node → enforce it +- If no explicit cap → check parent budget remaining → enforce parent +- Record node attribution regardless + +**Why this hook:** fires before any node body runs, requires zero modification to user node functions, works for all node types (sync, async, subgraph). + +### Level 3 — CrewAI agents and tasks + +**Mechanism:** Instantiate `ShekelEventListener(BaseEventListener)` and register it on `crewai_event_bus` at budget open. Deregister on budget close. + +**Events subscribed:** + +| Event | Action | +|---|---| +| `TaskStartedEvent` | Check task budget cap → circuit-break before task runs if exceeded | +| `AgentExecutionStartedEvent` | Check agent budget cap → circuit-break before agent runs | +| `LLMCallCompletedEvent` | Accumulate spend against agent + task child budgets | +| `TaskFailedEvent` | Record failure attribution | + +**Why this hook:** CrewAI's intended extensibility mechanism. Shekel becomes a first-class crew citizen, not a monkey-patch. Pre-task event fires before any LLM call is made — clean circuit break with zero wasted spend. + +### Level 3 — OpenAI Agents SDK + +**Mechanism:** Register a `ShekelBudgetGuardrail` as an input guardrail on `Runner` at budget open. + +**Guardrail behavior:** Before each agent turn, check remaining budget. If exhausted, raise `InputGuardrailTripwireTriggered` with budget context. The SDK's native exception handling propagates it to the caller. + +**Why this hook:** Repurposes the content-policy guardrail system as a cost-policy guardrail. No monkey-patching required. 
+ +### Level 4 — OpenClaw sessions + +**Mechanism:** Hook into `openclaw-sdk` agent lifecycle at budget open via the SDK's agent event callbacks. + +**Budget type:** `TemporalBudget` (rolling window) — not flat cap. OpenClaw is an always-on runtime; per-session rolling budgets are the correct primitive. + +**Circuit break behavior:** When session budget is exhausted, the agent enters a suspended state. It does not kill the Gateway process. Other agents on the same Gateway continue unaffected. + +--- + +## Exception Hierarchy + +Consistent with shekel's existing pattern (`BudgetExceededError`, `ToolBudgetExceededError`): + +```python +BudgetExceededError # base — catch-all, today's exception +├── NodeBudgetExceededError # LangGraph node exceeded its cap +├── AgentBudgetExceededError # agent exceeded its cap (CrewAI, OpenClaw) +├── TaskBudgetExceededError # task exceeded its cap (CrewAI) +└── SessionBudgetExceededError # session exceeded rolling-window cap (OpenClaw) +``` + +**Catch-all pattern** — works with all existing code, no changes needed: +```python +try: + graph.invoke(...) +except BudgetExceededError as e: + print(f"Budget exceeded: {e.spent:.4f} / {e.limit:.4f} USD") +``` + +**Level-specific handling** — opt-in for recovery logic: +```python +try: + crew.kickoff() +except TaskBudgetExceededError as e: + print(f"Task '{e.task_name}' exceeded its ${e.limit:.2f} cap") + # retry with cheaper model, skip task, etc. +except AgentBudgetExceededError as e: + print(f"Agent '{e.agent_name}' exceeded its ${e.limit:.2f} cap") +except BudgetExceededError: + # crew-level budget exhausted + raise +``` + +All level-specific exceptions carry the same core fields as `BudgetExceededError` (`spent`, `limit`) plus a level-specific identifier (`node_name`, `agent_name`, `task_name`). + +--- + +## Propagation Contract + +**Shekel's responsibility:** throw the right exception at the right level. +**Framework's responsibility:** decide whether to absorb, route, or escalate. 
+ +``` +NodeBudgetExceededError thrown + → LangGraph catches it (conditional edge / error node) OR + → propagates to AgentBudgetExceededError context (if inside CrewAI task) OR + → propagates to top-level budget() context +``` + +Shekel does not implement recovery logic. Recovery is the framework's domain — LangGraph has conditional edges, CrewAI has task callbacks, OpenAI Agents SDK has guardrail handlers. + +--- + +## Roadmap + +| Phase | Deliverable | +|---|---| +| v0.3.1 | `ShekelRuntime` — explicit detection + adapter wiring class. Two-tier implicit/explicit API. | +| v0.3.2 | LangGraph Level 2 — `add_node()` patch, `NodeBudgetExceededError`, `.node()` explicit API | +| v0.3.3 | CrewAI Level 3 — `BaseEventListener`, `AgentBudgetExceededError`, `TaskBudgetExceededError`, `.agent()` / `.task()` API | +| v0.3.4 | Loop detection — rate-of-change circuit breaker (velocity-based, not just threshold) | +| v0.3.5 | Tiered thresholds — N enforcement tiers (warn / fallback / disable tools / hard stop) | +| v0.3.6 | OpenClaw Level 4 — `openclaw-sdk` adapter, `SessionBudgetExceededError`, `TemporalBudget` per agent | +| v0.3.7 | Showback mode, budget tags, system-wide `ulimit`-style defaults | +| Post-1.0 | Cross-process budget spans, sampling mode, priority-aware preemption | + +--- + +## Open Questions + +These require further investigation before implementation: + +- **Q1:** Does Lobster (OpenClaw's YAML workflow engine) expose Python hooks for task start/end, or is all orchestration opaque to the SDK? +- **Q2:** In LangGraph, does `add_node()` patching handle async node functions and subgraph nodes correctly in all versions? +- **Q3:** When a CrewAI crew runs tasks in parallel, does `BaseEventListener` fire on the correct thread/async context for `contextvars` budget lookup? +- **Q4:** Can `ShekelBudgetGuardrail` (OpenAI Agents SDK) access the ContextVar budget from a guardrail callback, or does the SDK run guardrails in a separate context? 
+- **Q5:** What is the correct state value for a LangGraph node that is circuit-broken mid-execution — empty dict, sentinel, or should shekel inject a configurable default? diff --git a/tmp-h-design/implementation-plan.md b/tmp-h-design/implementation-plan.md new file mode 100644 index 0000000..f0769e7 --- /dev/null +++ b/tmp-h-design/implementation-plan.md @@ -0,0 +1,543 @@ +# Implementation Plan: Hierarchical Budget Enforcement + +**Branch:** `feat/hierarchical-integration` +**Design doc:** `tmp-h-design/design-hierarchical-integration.md` +**PRD:** `tmp-h-design/PRD-hierarchical-integration.md` + +Each phase is independently deliverable and releasable. Later phases depend on Phase 0 (ShekelRuntime) but not on each other. + +--- + +## Phase 0 — ShekelRuntime Foundation (v0.3.1) + +**Delivers:** Detection infrastructure and explicit API surface. Prerequisite for all later phases. + +### What ships + +- `ShekelRuntime` class — owns framework detection and adapter wiring +- `Budget.node()`, `Budget.agent()`, `Budget.task()` — explicit per-component cap API +- Enhanced `budget.tree()` — renders full hierarchy including node/agent/task child budgets + +### Files to create + +- `shekel/_runtime.py` — `ShekelRuntime` class + +### Files to modify + +- `shekel/_budget.py` — add `.node()`, `.agent()`, `.task()` methods; call `ShekelRuntime.probe()` on `__enter__` +- `shekel/__init__.py` — export new exception classes +- `shekel/exceptions.py` — add `NodeBudgetExceededError`, `AgentBudgetExceededError`, `TaskBudgetExceededError`, `SessionBudgetExceededError` + +### Key design + +```python +# shekel/_runtime.py +class ShekelRuntime: + """Probes for installed frameworks and wires adapters at budget open.""" + + ADAPTERS: list[type[ProviderAdapter]] = [] # populated by each phase + + @classmethod + def probe(cls, budget: Budget) -> None: + """Called once on budget.__enter__().
Activates all detected adapters.""" + for adapter_cls in cls.ADAPTERS: + try: + adapter = adapter_cls() + adapter.install_patches(budget) + except ImportError: + pass # framework not installed — silent skip +``` + +```python +# Budget gains: +def node(self, name: str, max_usd: float) -> Budget: + """Register an explicit cap for a LangGraph node.""" + child = Budget(max_usd=max_usd, name=f"node:{name}", parent=self) + self._node_budgets[name] = child + return child + +def agent(self, name: str, max_usd: float) -> Budget: + """Register an explicit cap for an agent (CrewAI / OpenClaw).""" + child = Budget(max_usd=max_usd, name=f"agent:{name}", parent=self) + self._agent_budgets[name] = child + return child + +def task(self, name: str, max_usd: float) -> Budget: + """Register an explicit cap for a task (CrewAI).""" + child = Budget(max_usd=max_usd, name=f"task:{name}", parent=self) + self._task_budgets[name] = child + return child +``` + +### TDD test file + +`tests/test_runtime.py` +- Test: `ShekelRuntime.probe()` runs without error when no frameworks installed +- Test: `Budget.node()` creates a child budget with correct limit +- Test: `Budget.agent()` creates a child budget with correct limit +- Test: `Budget.task()` creates a child budget with correct limit +- Test: `budget.tree()` includes node/agent/task children + +--- + +## Phase 1 — LangGraph Layer (v0.3.2) + +**Delivers:** Node-level circuit breaking for LangGraph. First framework integration. Highest-priority unmet need. 
+ +### What ships + +- `LangGraphAdapter` — patches `StateGraph.add_node()` transparently +- `NodeBudgetExceededError` — level-specific exception with `node_name` +- Per-node explicit caps via `b.node("name", max_usd=X)` +- Implicit mode: full attribution + parent enforcement, no per-node cap required +- Async node support +- `budget.tree()` shows per-node spend + +### Dependency + +- Phase 0 (ShekelRuntime) must be complete + +### Files to create + +- `shekel/providers/langgraph.py` — `LangGraphAdapter` +- `tests/test_langgraph_wrappers.py` — all LangGraph integration tests + +### Files to modify + +- `shekel/_runtime.py` — register `LangGraphAdapter` in `ADAPTERS` +- `shekel/exceptions.py` — `NodeBudgetExceededError` (already added in Phase 0) + +--- + +### Detailed Implementation: LangGraphAdapter + +#### 1. Detection + +```python +# shekel/providers/langgraph.py + +from __future__ import annotations +from shekel.providers.base import ProviderAdapter + +class LangGraphAdapter(ProviderAdapter): + name = "langgraph" + + def install_patches(self, budget: Budget) -> None: + import langgraph.graph.state # raises ImportError if not installed + from langgraph.graph.state import StateGraph + _patch_state_graph(StateGraph, budget) + + def remove_patches(self) -> None: + from langgraph.graph.state import StateGraph + _unpatch_state_graph(StateGraph) +``` + +#### 2. 
Patching `StateGraph.add_node()` + +```python +_ORIGINAL_ADD_NODE = None + +def _patch_state_graph(StateGraph, budget: Budget) -> None: + global _ORIGINAL_ADD_NODE + if _ORIGINAL_ADD_NODE is not None: + return # already patched (reference counting handled by _patch.py pattern) + + _ORIGINAL_ADD_NODE = StateGraph.add_node + + def patched_add_node(self, node: str | type, fn=None, **kwargs): + node_name = node if isinstance(node, str) else node.__name__ + if fn is not None: + fn = _wrap_node(node_name, fn, budget) + elif callable(node): + # add_node(my_fn): the callable is passed as the first argument and its __name__ becomes the node name + fn = _wrap_node(node_name, node, budget) + node = node_name + return _ORIGINAL_ADD_NODE(self, node, fn, **kwargs) + + StateGraph.add_node = patched_add_node +``` + +#### 3. Budget gate wrapper (sync + async) + +```python +import asyncio +import functools +from shekel._context import get_active_budget +from shekel.exceptions import NodeBudgetExceededError + +def _wrap_node(node_name: str, fn, budget: Budget): + """Wrap a LangGraph node function with a pre-execution budget gate.""" + + if asyncio.iscoroutinefunction(fn): + @functools.wraps(fn) + async def async_wrapper(state, *args, **kwargs): + _check_node_budget(node_name, budget) + return await fn(state, *args, **kwargs) + return async_wrapper + else: + @functools.wraps(fn) + def sync_wrapper(state, *args, **kwargs): + _check_node_budget(node_name, budget) + return fn(state, *args, **kwargs) + return sync_wrapper + + +def _check_node_budget(node_name: str, parent_budget: Budget) -> None: + """Check node-level and parent-level budget before node executes.""" + active = get_active_budget() + if active is None: + return + + # Check explicit node cap first + node_budget = active._node_budgets.get(node_name) + if node_budget is not None: + if node_budget.spent >= node_budget.max_usd: + raise NodeBudgetExceededError( + node_name=node_name, + spent=node_budget.spent, + limit=node_budget.max_usd, + ) + + # Check parent budget (implicit
enforcement) + if active.spent >= active.max_usd: + raise NodeBudgetExceededError( + node_name=node_name, + spent=active.spent, + limit=active.max_usd, + ) +``` + +#### 4. `NodeBudgetExceededError` + +```python +# shekel/exceptions.py — addition + +class NodeBudgetExceededError(BudgetExceededError): + """Raised when a LangGraph node exceeds its budget cap. + + Raised *before* the node body executes — zero LLM spend wasted + when an explicit node cap is set. + """ + + def __init__( + self, + node_name: str, + spent: float, + limit: float, + ) -> None: + self.node_name = node_name + super().__init__(spent=spent, limit=limit, model=f"node:{node_name}") + + def __str__(self) -> str: + return ( + f"Node budget exceeded for '{self.node_name}' " + f"(${self.spent:.4f} / ${self.limit:.2f})\n" + f" Tip: Increase b.node('{self.node_name}', max_usd=...) " + f"or remove the explicit cap to use parent budget only." + ) +``` + +--- + +### Test Plan: `tests/test_langgraph_wrappers.py` + +Follow TDD — write all tests before implementing. + +#### Test group 1: Detection and patching + +```python +def test_langgraph_adapter_skipped_when_not_installed(monkeypatch): + """ShekelRuntime silently skips LangGraph if not installed.""" + monkeypatch.setitem(sys.modules, "langgraph", None) + with budget(max_usd=5.00): + pass # no error + +def test_langgraph_adapter_patches_add_node_on_budget_open(): + """add_node is patched when budget context is entered.""" + ... + +def test_langgraph_adapter_restores_add_node_on_budget_close(): + """add_node is restored after budget context exits.""" + ... + +def test_langgraph_adapter_reference_counted(): + """Nested budgets: patch applied once, removed when last budget closes.""" + ... +``` + +#### Test group 2: Implicit mode (attribution + parent enforcement) + +```python +def test_node_spend_attributed_to_parent_budget(): + """LLM spend inside a node rolls up to parent budget.""" + ... 
+ +def test_parent_budget_circuit_breaks_during_node_execution(): + """Parent budget exhaustion raises NodeBudgetExceededError.""" + with budget(max_usd=0.01) as b: + graph = build_test_graph() # node makes LLM call > $0.01 + with pytest.raises(NodeBudgetExceededError) as exc: + graph.invoke({"input": "hello"}) + assert exc.value.node_name == "expensive_node" + +def test_tree_shows_per_node_attribution(): + """budget.tree() includes node-level spend breakdown.""" + ... + +def test_no_cap_nodes_share_parent_budget_uncapped(): + """Nodes without explicit cap are not individually limited.""" + ... +``` + +#### Test group 3: Explicit node caps + +```python +def test_explicit_node_cap_circuit_breaks_before_node_executes(): + """NodeBudgetExceededError raised before node body runs when cap exceeded.""" + with budget(max_usd=5.00) as b: + b.node("expensive_node", max_usd=0.01) + graph = build_test_graph() + # Pre-spend the node budget + b._node_budgets["expensive_node"]._record_spend(0.02, "gpt-4o", {}) + with pytest.raises(NodeBudgetExceededError) as exc: + graph.invoke({"input": "hello"}) + assert exc.value.node_name == "expensive_node" + assert exc.value.limit == 0.01 + +def test_explicit_cap_does_not_affect_other_nodes(): + """Capping one node does not restrict other nodes.""" + ... + +def test_unregistered_node_uses_parent_budget(): + """Node without explicit cap uses parent budget for enforcement.""" + ... +``` + +#### Test group 4: Async nodes + +```python +async def test_async_node_wrapped_correctly(): + """Async node functions are wrapped with async budget gate.""" + ... + +async def test_async_node_budget_exceeded_raises(): + """NodeBudgetExceededError raised correctly in async context.""" + ... 
+``` + +#### Test group 5: Exception contract + +```python +def test_node_budget_exceeded_is_subclass_of_budget_exceeded_error(): + """NodeBudgetExceededError is catchable as BudgetExceededError.""" + err = NodeBudgetExceededError("my_node", spent=0.05, limit=0.01) + assert isinstance(err, BudgetExceededError) + +def test_node_budget_exceeded_error_fields(): + """NodeBudgetExceededError carries node_name, spent, limit.""" + ... + +def test_existing_except_budget_exceeded_catches_node_error(): + """Existing user code catching BudgetExceededError still works.""" + with budget(max_usd=0.01): + with pytest.raises(BudgetExceededError): # not NodeBudgetExceededError + graph.invoke(...) +``` + +#### Test group 6: Edge cases + +```python +def test_looping_node_circuit_breaks_on_parent_budget(): + """A node in a loop is stopped when parent budget is exhausted.""" + # Build a graph with a loop (node → self-edge) + # Each iteration costs $0.001 + # Parent budget is $0.01 → should stop after ~10 iterations + ... + +def test_subgraph_node_wrapped_correctly(): + """Nodes that are themselves subgraphs (CompiledGraph) are wrapped.""" + ... + +def test_node_wrapped_only_once_with_nested_budgets(): + """Node function is not double-wrapped when nested budgets are used.""" + ... + +def test_budget_gate_overhead_under_1ms(): + """Pre-execution budget check completes in < 1ms.""" + ... 
+``` + +--- + +### Acceptance Criteria (Phase 1) + +- [ ] `StateGraph.add_node()` is transparently patched when LangGraph is installed and a budget is active +- [ ] Every node — sync and async — is wrapped with a pre-execution budget gate +- [ ] `NodeBudgetExceededError` is raised with correct `node_name`, `spent`, `limit` +- [ ] `NodeBudgetExceededError` is a subclass of `BudgetExceededError` — existing catch-all still works +- [ ] Implicit mode: parent budget enforces total, per-node attribution tracked +- [ ] Explicit mode: `b.node("name", max_usd=X)` creates a hard cap for that node +- [ ] Looping nodes are circuit-broken when parent budget exhausts +- [ ] Patch is reference-counted: nested budgets don't double-patch +- [ ] 100% code coverage +- [ ] All linters pass: `black`, `isort`, `ruff`, `mypy` + +--- + +## Phase 2 — CrewAI Layer (v0.3.3) + +**Delivers:** Agent-level and task-level circuit breaking for CrewAI. + +### What ships + +- `CrewAIAdapter` — registers as `BaseEventListener` on `crewai_event_bus` +- `AgentBudgetExceededError` — with `agent_name`, `spent`, `limit` +- `TaskBudgetExceededError` — with `task_name`, `spent`, `limit` +- Pre-task and pre-agent circuit breaking (before any LLM call is made) +- `b.agent("name", max_usd=X)` and `b.task("name", max_usd=X)` explicit API + +### Files to create + +- `shekel/providers/crewai.py` — `CrewAIAdapter` extending existing `crewai.py` +- `tests/test_crewai_wrappers.py` + +### Key design + +`CrewAIAdapter.install_patches()` registers `ShekelEventListener` on the global `crewai_event_bus`. On `TaskStartedEvent`, check task cap → raise `TaskBudgetExceededError` if exceeded. On `AgentExecutionStartedEvent`, check agent cap. On `LLMCallCompletedEvent`, record spend to agent + task child budgets. 
+ +### Test groups (high-level) + +- Agent cap circuit-breaks before agent executes +- Task cap circuit-breaks before task's first LLM call +- Uncapped agents/tasks share parent budget +- Parallel crew tasks: correct `contextvars` budget lookup per task +- Exception hierarchy: both new exceptions are subclasses of `BudgetExceededError` +- Listener deregistered on budget close + +--- + +## Phase 3 — Loop Detection (v0.3.4) + +**Delivers:** Velocity-based circuit breaking. Catches runaway loops before the absolute limit is hit. + +### What ships + +- Spend velocity tracker on `Budget` class +- `loop_detection_multiplier` and `loop_detection_window` parameters on `budget()` +- Circuit-break when instantaneous rate > N× rolling baseline +- Raises `BudgetExceededError` with message indicating loop detection triggered + +### Files to modify + +- `shekel/_budget.py` — add velocity tracking to `_record_spend()` +- `tests/test_loop_detection.py` + +### Key design + +Rolling window: store `(timestamp, amount)` tuples in a deque. On each `_record_spend()` call, evict entries older than `loop_detection_window` seconds, compute rate. If rate > `multiplier × baseline`, raise. + +Default: disabled (`loop_detection_multiplier=None`). Opt-in only. + +--- + +## Phase 4 — Tiered Thresholds (v0.3.5) + +**Delivers:** N-tier enforcement (warn / fallback / disable tools / stop) instead of binary warn + hard stop. + +### What ships + +- `tiers` parameter on `budget()` +- Actions: `"warn"`, `"fallback:<model>"`, `"disable_tools"`, `"stop"` +- Backward compatible: existing `warn_at` and `fallback` still work as before + +### Files to modify + +- `shekel/_budget.py` — tier evaluation in `_check_limits()` +- `shekel/_run_config.py` — `tiers` parameter +- `tests/test_tiered_thresholds.py` + +--- + +## Phase 5 — OpenClaw Layer (v0.3.6) + +**Delivers:** Session-level circuit breaking for OpenClaw always-on agents. 
+ +### What ships + +- `OpenClawAdapter` — hooks into `openclaw-sdk` agent lifecycle +- `SessionBudgetExceededError` — with `agent_name`, `spent`, `limit`, `window` +- Session suspension on budget breach (does not kill Gateway) +- `TemporalBudget` as default budget type for OpenClaw contexts + +### Files to create + +- `shekel/providers/openclaw.py` — `OpenClawAdapter` +- `tests/test_openclaw_wrappers.py` + +### Open question to resolve first + +Does `openclaw-sdk` expose Python lifecycle hooks for agent start/stop/suspend? If not, the alternative is to patch the agent's LLM call path and use `TemporalBudget`'s rolling window as the enforcement mechanism. + +--- + +## Phase 6 — DX Layer (v0.3.7) + +**Delivers:** Adoption-friendly features — showback mode, budget tags, system-wide policy. + +### What ships + +- `mode="showback"` parameter on `budget()` — track everything, raise nothing +- `tags` parameter on `@tool()` decorator +- `budget.summary(group_by="tags")` output +- System-wide defaults via `shekel.configure(per_llm_call_max=0.10)` (or env vars) + +### Files to modify + +- `shekel/_budget.py` — `mode` parameter, skip raise in showback +- `shekel/_tool.py` — `tags` parameter on `@tool` +- `shekel/_run_config.py` — global defaults +- `tests/test_showback_mode.py`, `tests/test_budget_tags.py` + +--- + +## Cross-Cutting Concerns (All Phases) + +### Testing standards (from CLAUDE.md) + +- TDD: all tests written before implementation +- 100% coverage on all new code +- Test files named by domain: `test_langgraph_wrappers.py`, `test_crewai_wrappers.py`, etc. 
+- Run after each phase: `pytest --cov=shekel --cov-report=term-missing` + +### Linting (run after each phase) + +```bash +python -m black shekel/ tests/ +python -m isort shekel/ tests/ +python -m ruff check shekel/ tests/ +python -m mypy shekel/ +``` + +### Optional dependency pattern (existing convention) + +Each adapter wraps its import in `try/except ImportError` or raises `ImportError` from `install_patches()` — `ShekelRuntime` catches and silently skips. + +### Backward compatibility + +Every phase must pass the full existing test suite unchanged. No existing API is modified — only additive changes. + +--- + +## Delivery Summary + +| Phase | Version | Layer | Independent? | +|---|---|---|---| +| Phase 0 | v0.3.1 | ShekelRuntime + foundation API | Yes (prerequisite) | +| Phase 1 | v0.3.2 | LangGraph — node-level | Depends on Phase 0 | +| Phase 2 | v0.3.3 | CrewAI — agent + task level | Depends on Phase 0 | +| Phase 3 | v0.3.4 | Loop detection | Depends on Phase 0 | +| Phase 4 | v0.3.5 | Tiered thresholds | Depends on Phase 0 | +| Phase 5 | v0.3.6 | OpenClaw — session level | Depends on Phase 0 | +| Phase 6 | v0.3.7 | DX layer | Depends on Phase 0 | + +Phases 1–6 are independent of each other and can be developed in parallel after Phase 0.